diff --git a/.gitignore b/.gitignore index b73fbe52f..859d7d08f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +/.gopath cmd/dirdiff/dirdiff cmd/gentestdata/gentestdata cmd/restic/restic diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json new file mode 100644 index 000000000..32a7cad3c --- /dev/null +++ b/Godeps/Godeps.json @@ -0,0 +1,42 @@ +{ + "ImportPath": "github.com/restic/restic", + "GoVersion": "go1.4.2", + "Packages": [ + "./..." + ], + "Deps": [ + { + "ImportPath": "github.com/jessevdk/go-flags", + "Comment": "v1-293-g5e11878", + "Rev": "5e118789801496c93ba210d34ef1f2ce5a9173bd" + }, + { + "ImportPath": "github.com/juju/errors", + "Rev": "4567a5e69fd3130ca0d89f69478e7ac025b67452" + }, + { + "ImportPath": "github.com/kr/fs", + "Rev": "2788f0dbd16903de03cb8186e5c7d97b69ad387b" + }, + { + "ImportPath": "github.com/pkg/sftp", + "Rev": "506297c9013d2893d5c5daaa9155e7333a1c58de" + }, + { + "ImportPath": "golang.org/x/crypto/pbkdf2", + "Rev": "24ffb5feb3312a39054178a4b0a4554fc2201248" + }, + { + "ImportPath": "golang.org/x/crypto/poly1305", + "Rev": "24ffb5feb3312a39054178a4b0a4554fc2201248" + }, + { + "ImportPath": "golang.org/x/crypto/scrypt", + "Rev": "24ffb5feb3312a39054178a4b0a4554fc2201248" + }, + { + "ImportPath": "golang.org/x/crypto/ssh", + "Rev": "24ffb5feb3312a39054178a4b0a4554fc2201248" + } + ] +} diff --git a/Godeps/Readme b/Godeps/Readme new file mode 100644 index 000000000..4cdaa53d5 --- /dev/null +++ b/Godeps/Readme @@ -0,0 +1,5 @@ +This directory tree is generated automatically by godep. + +Please do not edit. + +See https://github.com/tools/godep for more information. diff --git a/Godeps/_workspace/.gitignore b/Godeps/_workspace/.gitignore new file mode 100644 index 000000000..f037d684e --- /dev/null +++ b/Godeps/_workspace/.gitignore @@ -0,0 +1,2 @@ +/pkg +/bin diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/.travis.yml b/Godeps/_workspace/src/github.com/jessevdk/go-flags/.travis.yml new file mode 100644 index 000000000..3165f0042 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/.travis.yml @@ -0,0 +1,35 @@ +language: go + +install: + # go-flags + - go get -d -v ./... + - go build -v ./... + + # linting + - go get golang.org/x/tools/cmd/vet + - go get github.com/golang/lint + - go install github.com/golang/lint/golint + + # code coverage + - go get golang.org/x/tools/cmd/cover + - go get github.com/onsi/ginkgo/ginkgo + - go get github.com/modocache/gover + - if [ "$TRAVIS_SECURE_ENV_VARS" = "true" ]; then go get github.com/mattn/goveralls; fi + +script: + # go-flags + - $(exit $(gofmt -l . | wc -l)) + - go test -v ./... + + # linting + - go tool vet -all=true -v=true . || true + - $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/golint ./... + + # code coverage + - $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/ginkgo -r -cover + - $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/gover + - if [ "$TRAVIS_SECURE_ENV_VARS" = "true" ]; then $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci -repotoken $COVERALLS_TOKEN; fi + +env: + # coveralls.io + secure: "RCYbiB4P0RjQRIoUx/vG/AjP3mmYCbzOmr86DCww1Z88yNcy3hYr3Cq8rpPtYU5v0g7wTpu4adaKIcqRE9xknYGbqj3YWZiCoBP1/n4Z+9sHW3Dsd9D/GRGeHUus0laJUGARjWoCTvoEtOgTdGQDoX7mH+pUUY0FBltNYUdOiiU=" diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/LICENSE b/Godeps/_workspace/src/github.com/jessevdk/go-flags/LICENSE new file mode 100644 index 000000000..bcca0d521 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2012 Jesse van den Kieboom. All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of Google Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/README.md b/Godeps/_workspace/src/github.com/jessevdk/go-flags/README.md new file mode 100644 index 000000000..b6faef6d3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/README.md @@ -0,0 +1,131 @@ +go-flags: a go library for parsing command line arguments +========================================================= + +[![GoDoc](https://godoc.org/github.com/jessevdk/go-flags?status.png)](https://godoc.org/github.com/jessevdk/go-flags) [![Build Status](https://travis-ci.org/jessevdk/go-flags.svg?branch=master)](https://travis-ci.org/jessevdk/go-flags) [![Coverage Status](https://img.shields.io/coveralls/jessevdk/go-flags.svg)](https://coveralls.io/r/jessevdk/go-flags?branch=master) + +This library provides similar functionality to the builtin flag library of +go, but provides much more functionality and nicer formatting. From the +documentation: + +Package flags provides an extensive command line option parser. +The flags package is similar in functionality to the go builtin flag package +but provides more options and uses reflection to provide a convenient and +succinct way of specifying command line options. + +Supported features: +* Options with short names (-v) +* Options with long names (--verbose) +* Options with and without arguments (bool v.s. other type) +* Options with optional arguments and default values +* Multiple option groups each containing a set of options +* Generate and print well-formatted help message +* Passing remaining command line arguments after -- (optional) +* Ignoring unknown command line options (optional) +* Supports -I/usr/include -I=/usr/include -I /usr/include option argument specification +* Supports multiple short options -aux +* Supports all primitive go types (string, int{8..64}, uint{8..64}, float) +* Supports same option multiple times (can store in slice or last option counts) +* Supports maps +* Supports function callbacks +* Supports namespaces for (nested) option groups + +The flags package uses structs, reflection and struct field tags +to allow users to specify command line options. This results in very simple +and concise specification of your application options. For example: + + type Options struct { + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"` + } + +This specifies one option with a short name -v and a long name --verbose. +When either -v or --verbose is found on the command line, a 'true' value +will be appended to the Verbose field. e.g. when specifying -vvv, the +resulting value of Verbose will be {[true, true, true]}. + +Example: +-------- + var opts struct { + // Slice of bool will append 'true' each time the option + // is encountered (can be set multiple times, like -vvv) + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"` + + // Example of automatic marshalling to desired type (uint) + Offset uint `long:"offset" description:"Offset"` + + // Example of a callback, called each time the option is found. + Call func(string) `short:"c" description:"Call phone number"` + + // Example of a required flag + Name string `short:"n" long:"name" description:"A name" required:"true"` + + // Example of a value name + File string `short:"f" long:"file" description:"A file" value-name:"FILE"` + + // Example of a pointer + Ptr *int `short:"p" description:"A pointer to an integer"` + + // Example of a slice of strings + StringSlice []string `short:"s" description:"A slice of strings"` + + // Example of a slice of pointers + PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"` + + // Example of a map + IntMap map[string]int `long:"intmap" description:"A map from string to int"` + } + + // Callback which will invoke callto: to call a number. + // Note that this works just on OS X (and probably only with + // Skype) but it shows the idea. + opts.Call = func(num string) { + cmd := exec.Command("open", "callto:"+num) + cmd.Start() + cmd.Process.Release() + } + + // Make some fake arguments to parse. + args := []string{ + "-vv", + "--offset=5", + "-n", "Me", + "-p", "3", + "-s", "hello", + "-s", "world", + "--ptrslice", "hello", + "--ptrslice", "world", + "--intmap", "a:1", + "--intmap", "b:5", + "arg1", + "arg2", + "arg3", + } + + // Parse flags from `args'. Note that here we use flags.ParseArgs for + // the sake of making a working example. Normally, you would simply use + // flags.Parse(&opts) which uses os.Args + args, err := flags.ParseArgs(&opts, args) + + if err != nil { + panic(err) + os.Exit(1) + } + + fmt.Printf("Verbosity: %v\n", opts.Verbose) + fmt.Printf("Offset: %d\n", opts.Offset) + fmt.Printf("Name: %s\n", opts.Name) + fmt.Printf("Ptr: %d\n", *opts.Ptr) + fmt.Printf("StringSlice: %v\n", opts.StringSlice) + fmt.Printf("PtrSlice: [%v %v]\n", *opts.PtrSlice[0], *opts.PtrSlice[1]) + fmt.Printf("IntMap: [a:%v b:%v]\n", opts.IntMap["a"], opts.IntMap["b"]) + fmt.Printf("Remaining args: %s\n", strings.Join(args, " ")) + + // Output: Verbosity: [true true] + // Offset: 5 + // Name: Me + // Ptr: 3 + // StringSlice: [hello world] + // PtrSlice: [hello world] + // IntMap: [a:1 b:5] + // Remaining args: arg1 arg2 arg3 + +More information can be found in the godocs: diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg.go new file mode 100644 index 000000000..fd8db9c77 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg.go @@ -0,0 +1,21 @@ +package flags + +import ( + "reflect" +) + +// Arg represents a positional argument on the command line. +type Arg struct { + // The name of the positional argument (used in the help) + Name string + + // A description of the positional argument (used in the help) + Description string + + value reflect.Value + tag multiTag +} + +func (a *Arg) isRemaining() bool { + return a.value.Type().Kind() == reflect.Slice +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg_test.go new file mode 100644 index 000000000..faea28093 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/arg_test.go @@ -0,0 +1,53 @@ +package flags + +import ( + "testing" +) + +func TestPositional(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Positional struct { + Command int + Filename string + Rest []string + } `positional-args:"yes" required:"yes"` + }{} + + p := NewParser(&opts, Default) + ret, err := p.ParseArgs([]string{"10", "arg_test.go", "a", "b"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + if opts.Positional.Command != 10 { + t.Fatalf("Expected opts.Positional.Command to be 10, but got %v", opts.Positional.Command) + } + + if opts.Positional.Filename != "arg_test.go" { + t.Fatalf("Expected opts.Positional.Filename to be \"arg_test.go\", but got %v", opts.Positional.Filename) + } + + assertStringArray(t, opts.Positional.Rest, []string{"a", "b"}) + assertStringArray(t, ret, []string{}) +} + +func TestPositionalRequired(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Positional struct { + Command int + Filename string + Rest []string + } `positional-args:"yes" required:"yes"` + }{} + + p := NewParser(&opts, None) + _, err := p.ParseArgs([]string{"10"}) + + assertError(t, err, ErrRequired, "the required argument `Filename` was not provided") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/assert_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/assert_test.go new file mode 100644 index 000000000..8e06636b6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/assert_test.go @@ -0,0 +1,177 @@ +package flags + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path" + "runtime" + "testing" +) + +func assertCallerInfo() (string, int) { + ptr := make([]uintptr, 15) + n := runtime.Callers(1, ptr) + + if n == 0 { + return "", 0 + } + + mef := runtime.FuncForPC(ptr[0]) + mefile, meline := mef.FileLine(ptr[0]) + + for i := 2; i < n; i++ { + f := runtime.FuncForPC(ptr[i]) + file, line := f.FileLine(ptr[i]) + + if file != mefile { + return file, line + } + } + + return mefile, meline +} + +func assertErrorf(t *testing.T, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + + file, line := assertCallerInfo() + + t.Errorf("%s:%d: %s", path.Base(file), line, msg) +} + +func assertFatalf(t *testing.T, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + + file, line := assertCallerInfo() + + t.Fatalf("%s:%d: %s", path.Base(file), line, msg) +} + +func assertString(t *testing.T, a string, b string) { + if a != b { + assertErrorf(t, "Expected %#v, but got %#v", b, a) + } +} + +func assertStringArray(t *testing.T, a []string, b []string) { + if len(a) != len(b) { + assertErrorf(t, "Expected %#v, but got %#v", b, a) + return + } + + for i, v := range a { + if b[i] != v { + assertErrorf(t, "Expected %#v, but got %#v", b, a) + return + } + } +} + +func assertBoolArray(t *testing.T, a []bool, b []bool) { + if len(a) != len(b) { + assertErrorf(t, "Expected %#v, but got %#v", b, a) + return + } + + for i, v := range a { + if b[i] != v { + assertErrorf(t, "Expected %#v, but got %#v", b, a) + return + } + } +} + +func assertParserSuccess(t *testing.T, data interface{}, args ...string) (*Parser, []string) { + parser := NewParser(data, Default&^PrintErrors) + ret, err := parser.ParseArgs(args) + + if err != nil { + t.Fatalf("Unexpected parse error: %s", err) + return nil, nil + } + + return parser, ret +} + +func assertParseSuccess(t *testing.T, data interface{}, args ...string) []string { + _, ret := assertParserSuccess(t, data, args...) + return ret +} + +func assertError(t *testing.T, err error, typ ErrorType, msg string) { + if err == nil { + assertFatalf(t, "Expected error: %s", msg) + return + } + + if e, ok := err.(*Error); !ok { + assertFatalf(t, "Expected Error type, but got %#v", err) + } else { + if e.Type != typ { + assertErrorf(t, "Expected error type {%s}, but got {%s}", typ, e.Type) + } + + if e.Message != msg { + assertErrorf(t, "Expected error message %#v, but got %#v", msg, e.Message) + } + } +} + +func assertParseFail(t *testing.T, typ ErrorType, msg string, data interface{}, args ...string) []string { + parser := NewParser(data, Default&^PrintErrors) + ret, err := parser.ParseArgs(args) + + assertError(t, err, typ, msg) + return ret +} + +func diff(a, b string) (string, error) { + atmp, err := ioutil.TempFile("", "help-diff") + + if err != nil { + return "", err + } + + btmp, err := ioutil.TempFile("", "help-diff") + + if err != nil { + return "", err + } + + if _, err := io.WriteString(atmp, a); err != nil { + return "", err + } + + if _, err := io.WriteString(btmp, b); err != nil { + return "", err + } + + ret, err := exec.Command("diff", "-u", "-d", "--label", "got", atmp.Name(), "--label", "expected", btmp.Name()).Output() + + os.Remove(atmp.Name()) + os.Remove(btmp.Name()) + + if err.Error() == "exit status 1" { + return string(ret), nil + } + + return string(ret), err +} + +func assertDiff(t *testing.T, actual, expected, msg string) { + if actual == expected { + return + } + + ret, err := diff(actual, expected) + + if err != nil { + assertErrorf(t, "Unexpected diff error: %s", err) + assertErrorf(t, "Unexpected %s, expected:\n\n%s\n\nbut got\n\n%s", msg, expected, actual) + } else { + assertErrorf(t, "Unexpected %s:\n\n%s", msg, ret) + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/check_crosscompile.sh b/Godeps/_workspace/src/github.com/jessevdk/go-flags/check_crosscompile.sh new file mode 100644 index 000000000..c494f6119 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/check_crosscompile.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -e + +echo '# linux arm7' +GOARM=7 GOARCH=arm GOOS=linux go build +echo '# linux arm5' +GOARM=5 GOARCH=arm GOOS=linux go build +echo '# windows 386' +GOARCH=386 GOOS=windows go build +echo '# windows amd64' +GOARCH=amd64 GOOS=windows go build +echo '# darwin' +GOARCH=amd64 GOOS=darwin go build +echo '# freebsd' +GOARCH=amd64 GOOS=freebsd go build diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/closest.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/closest.go new file mode 100644 index 000000000..3b518757c --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/closest.go @@ -0,0 +1,59 @@ +package flags + +func levenshtein(s string, t string) int { + if len(s) == 0 { + return len(t) + } + + if len(t) == 0 { + return len(s) + } + + dists := make([][]int, len(s)+1) + for i := range dists { + dists[i] = make([]int, len(t)+1) + dists[i][0] = i + } + + for j := range t { + dists[0][j] = j + } + + for i, sc := range s { + for j, tc := range t { + if sc == tc { + dists[i+1][j+1] = dists[i][j] + } else { + dists[i+1][j+1] = dists[i][j] + 1 + if dists[i+1][j] < dists[i+1][j+1] { + dists[i+1][j+1] = dists[i+1][j] + 1 + } + if dists[i][j+1] < dists[i+1][j+1] { + dists[i+1][j+1] = dists[i][j+1] + 1 + } + } + } + } + + return dists[len(s)][len(t)] +} + +func closestChoice(cmd string, choices []string) (string, int) { + if len(choices) == 0 { + return "", 0 + } + + mincmd := -1 + mindist := -1 + + for i, c := range choices { + l := levenshtein(cmd, c) + + if mincmd < 0 || l < mindist { + mindist = l + mincmd = i + } + } + + return choices[mincmd], mindist +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/command.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command.go new file mode 100644 index 000000000..13332ae33 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command.go @@ -0,0 +1,106 @@ +package flags + +// Command represents an application command. Commands can be added to the +// parser (which itself is a command) and are selected/executed when its name +// is specified on the command line. The Command type embeds a Group and +// therefore also carries a set of command specific options. +type Command struct { + // Embedded, see Group for more information + *Group + + // The name by which the command can be invoked + Name string + + // The active sub command (set by parsing) or nil + Active *Command + + // Whether subcommands are optional + SubcommandsOptional bool + + // Aliases for the command + Aliases []string + + // Whether positional arguments are required + ArgsRequired bool + + commands []*Command + hasBuiltinHelpGroup bool + args []*Arg +} + +// Commander is an interface which can be implemented by any command added in +// the options. When implemented, the Execute method will be called for the last +// specified (sub)command providing the remaining command line arguments. +type Commander interface { + // Execute will be called for the last active (sub)command. The + // args argument contains the remaining command line arguments. The + // error that Execute returns will be eventually passed out of the + // Parse method of the Parser. + Execute(args []string) error +} + +// Usage is an interface which can be implemented to show a custom usage string +// in the help message shown for a command. +type Usage interface { + // Usage is called for commands to allow customized printing of command + // usage in the generated help message. + Usage() string +} + +// AddCommand adds a new command to the parser with the given name and data. The +// data needs to be a pointer to a struct from which the fields indicate which +// options are in the command. The provided data can implement the Command and +// Usage interfaces. +func (c *Command) AddCommand(command string, shortDescription string, longDescription string, data interface{}) (*Command, error) { + cmd := newCommand(command, shortDescription, longDescription, data) + + cmd.parent = c + + if err := cmd.scan(); err != nil { + return nil, err + } + + c.commands = append(c.commands, cmd) + return cmd, nil +} + +// AddGroup adds a new group to the command with the given name and data. The +// data needs to be a pointer to a struct from which the fields indicate which +// options are in the group. +func (c *Command) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) { + group := newGroup(shortDescription, longDescription, data) + + group.parent = c + + if err := group.scanType(c.scanSubcommandHandler(group)); err != nil { + return nil, err + } + + c.groups = append(c.groups, group) + return group, nil +} + +// Commands returns a list of subcommands of this command. +func (c *Command) Commands() []*Command { + return c.commands +} + +// Find locates the subcommand with the given name and returns it. If no such +// command can be found Find will return nil. +func (c *Command) Find(name string) *Command { + for _, cc := range c.commands { + if cc.match(name) { + return cc + } + } + + return nil +} + +// Args returns a list of positional arguments associated with this command. +func (c *Command) Args() []*Arg { + ret := make([]*Arg, len(c.args)) + copy(ret, c.args) + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_private.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_private.go new file mode 100644 index 000000000..5d30a8afe --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_private.go @@ -0,0 +1,250 @@ +package flags + +import ( + "reflect" + "sort" + "strings" + "unsafe" +) + +type lookup struct { + shortNames map[string]*Option + longNames map[string]*Option + + commands map[string]*Command +} + +func newCommand(name string, shortDescription string, longDescription string, data interface{}) *Command { + return &Command{ + Group: newGroup(shortDescription, longDescription, data), + Name: name, + } +} + +func (c *Command) scanSubcommandHandler(parentg *Group) scanHandler { + f := func(realval reflect.Value, sfield *reflect.StructField) (bool, error) { + mtag := newMultiTag(string(sfield.Tag)) + + if err := mtag.Parse(); err != nil { + return true, err + } + + positional := mtag.Get("positional-args") + + if len(positional) != 0 { + stype := realval.Type() + + for i := 0; i < stype.NumField(); i++ { + field := stype.Field(i) + + m := newMultiTag((string(field.Tag))) + + if err := m.Parse(); err != nil { + return true, err + } + + name := m.Get("positional-arg-name") + + if len(name) == 0 { + name = field.Name + } + + arg := &Arg{ + Name: name, + Description: m.Get("description"), + + value: realval.Field(i), + tag: m, + } + + c.args = append(c.args, arg) + + if len(mtag.Get("required")) != 0 { + c.ArgsRequired = true + } + } + + return true, nil + } + + subcommand := mtag.Get("command") + + if len(subcommand) != 0 { + ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr())) + + shortDescription := mtag.Get("description") + longDescription := mtag.Get("long-description") + subcommandsOptional := mtag.Get("subcommands-optional") + aliases := mtag.GetMany("alias") + + subc, err := c.AddCommand(subcommand, shortDescription, longDescription, ptrval.Interface()) + + if err != nil { + return true, err + } + + if len(subcommandsOptional) > 0 { + subc.SubcommandsOptional = true + } + + if len(aliases) > 0 { + subc.Aliases = aliases + } + + return true, nil + } + + return parentg.scanSubGroupHandler(realval, sfield) + } + + return f +} + +func (c *Command) scan() error { + return c.scanType(c.scanSubcommandHandler(c.Group)) +} + +func (c *Command) eachCommand(f func(*Command), recurse bool) { + f(c) + + for _, cc := range c.commands { + if recurse { + cc.eachCommand(f, true) + } else { + f(cc) + } + } +} + +func (c *Command) eachActiveGroup(f func(cc *Command, g *Group)) { + c.eachGroup(func(g *Group) { + f(c, g) + }) + + if c.Active != nil { + c.Active.eachActiveGroup(f) + } +} + +func (c *Command) addHelpGroups(showHelp func() error) { + if !c.hasBuiltinHelpGroup { + c.addHelpGroup(showHelp) + c.hasBuiltinHelpGroup = true + } + + for _, cc := range c.commands { + cc.addHelpGroups(showHelp) + } +} + +func (c *Command) makeLookup() lookup { + ret := lookup{ + shortNames: make(map[string]*Option), + longNames: make(map[string]*Option), + commands: make(map[string]*Command), + } + + c.eachGroup(func(g *Group) { + for _, option := range g.options { + if option.ShortName != 0 { + ret.shortNames[string(option.ShortName)] = option + } + + if len(option.LongName) > 0 { + ret.longNames[option.LongNameWithNamespace()] = option + } + } + }) + + for _, subcommand := range c.commands { + ret.commands[subcommand.Name] = subcommand + + for _, a := range subcommand.Aliases { + ret.commands[a] = subcommand + } + } + + return ret +} + +func (c *Command) groupByName(name string) *Group { + if grp := c.Group.groupByName(name); grp != nil { + return grp + } + + for _, subc := range c.commands { + prefix := subc.Name + "." + + if strings.HasPrefix(name, prefix) { + if grp := subc.groupByName(name[len(prefix):]); grp != nil { + return grp + } + } else if name == subc.Name { + return subc.Group + } + } + + return nil +} + +type commandList []*Command + +func (c commandList) Less(i, j int) bool { + return c[i].Name < c[j].Name +} + +func (c commandList) Len() int { + return len(c) +} + +func (c commandList) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +func (c *Command) sortedCommands() []*Command { + ret := make(commandList, len(c.commands)) + copy(ret, c.commands) + + sort.Sort(ret) + return []*Command(ret) +} + +func (c *Command) match(name string) bool { + if c.Name == name { + return true + } + + for _, v := range c.Aliases { + if v == name { + return true + } + } + + return false +} + +func (c *Command) hasCliOptions() bool { + ret := false + + c.eachGroup(func(g *Group) { + if g.isBuiltinHelp { + return + } + + for _, opt := range g.options { + if opt.canCli() { + ret = true + } + } + }) + + return ret +} + +func (c *Command) fillParseState(s *parseState) { + s.positional = make([]*Arg, len(c.args)) + copy(s.positional, c.args) + + s.lookup = c.makeLookup() + s.command = c +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_test.go new file mode 100644 index 000000000..a093e1588 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/command_test.go @@ -0,0 +1,354 @@ +package flags + +import ( + "fmt" + "testing" +) + +func TestCommandInline(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Command struct { + G bool `short:"g"` + } `command:"cmd"` + }{} + + p, ret := assertParserSuccess(t, &opts, "-v", "cmd", "-g") + + assertStringArray(t, ret, []string{}) + + if p.Active == nil { + t.Errorf("Expected active command") + } + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.Command.G { + t.Errorf("Expected Command.G to be true") + } + + if p.Command.Find("cmd") != p.Active { + t.Errorf("Expected to find command `cmd' to be active") + } +} + +func TestCommandInlineMulti(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + C1 struct { + } `command:"c1"` + + C2 struct { + G bool `short:"g"` + } `command:"c2"` + }{} + + p, ret := assertParserSuccess(t, &opts, "-v", "c2", "-g") + + assertStringArray(t, ret, []string{}) + + if p.Active == nil { + t.Errorf("Expected active command") + } + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.C2.G { + t.Errorf("Expected C2.G to be true") + } + + if p.Command.Find("c1") == nil { + t.Errorf("Expected to find command `c1'") + } + + if c2 := p.Command.Find("c2"); c2 == nil { + t.Errorf("Expected to find command `c2'") + } else if c2 != p.Active { + t.Errorf("Expected to find command `c2' to be active") + } +} + +func TestCommandFlagOrder1(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Command struct { + G bool `short:"g"` + } `command:"cmd"` + }{} + + assertParseFail(t, ErrUnknownFlag, "unknown flag `g'", &opts, "-v", "-g", "cmd") +} + +func TestCommandFlagOrder2(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Command struct { + G bool `short:"g"` + } `command:"cmd"` + }{} + + assertParseFail(t, ErrUnknownFlag, "unknown flag `v'", &opts, "cmd", "-v", "-g") +} + +func TestCommandEstimate(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Cmd1 struct { + } `command:"remove"` + + Cmd2 struct { + } `command:"add"` + }{} + + p := NewParser(&opts, None) + _, err := p.ParseArgs([]string{}) + + assertError(t, err, ErrCommandRequired, "Please specify one command of: add or remove") +} + +func TestCommandEstimate2(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Cmd1 struct { + } `command:"remove"` + + Cmd2 struct { + } `command:"add"` + }{} + + p := NewParser(&opts, None) + _, err := p.ParseArgs([]string{"rmive"}) + + assertError(t, err, ErrUnknownCommand, "Unknown command `rmive', did you mean `remove'?") +} + +type testCommand struct { + G bool `short:"g"` + Executed bool + EArgs []string +} + +func (c *testCommand) Execute(args []string) error { + c.Executed = true + c.EArgs = args + + return nil +} + +func TestCommandExecute(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Command testCommand `command:"cmd"` + }{} + + assertParseSuccess(t, &opts, "-v", "cmd", "-g", "a", "b") + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.Command.Executed { + t.Errorf("Did not execute command") + } + + if !opts.Command.G { + t.Errorf("Expected Command.C to be true") + } + + assertStringArray(t, opts.Command.EArgs, []string{"a", "b"}) +} + +func TestCommandClosest(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Cmd1 struct { + } `command:"remove"` + + Cmd2 struct { + } `command:"add"` + }{} + + args := assertParseFail(t, ErrUnknownCommand, "Unknown command `addd', did you mean `add'?", &opts, "-v", "addd") + + assertStringArray(t, args, []string{"addd"}) +} + +func TestCommandAdd(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + }{} + + var cmd = struct { + G bool `short:"g"` + }{} + + p := NewParser(&opts, Default) + c, err := p.AddCommand("cmd", "", "", &cmd) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + ret, err := p.ParseArgs([]string{"-v", "cmd", "-g", "rest"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + assertStringArray(t, ret, []string{"rest"}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !cmd.G { + t.Errorf("Expected Command.G to be true") + } + + if p.Command.Find("cmd") != c { + t.Errorf("Expected to find command `cmd'") + } + + if p.Commands()[0] != c { + t.Errorf("Expected command %#v, but got %#v", c, p.Commands()[0]) + } + + if c.Options()[0].ShortName != 'g' { + t.Errorf("Expected short name `g' but got %v", c.Options()[0].ShortName) + } +} + +func TestCommandNestedInline(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Command struct { + G bool `short:"g"` + + Nested struct { + N string `long:"n"` + } `command:"nested"` + } `command:"cmd"` + }{} + + p, ret := assertParserSuccess(t, &opts, "-v", "cmd", "-g", "nested", "--n", "n", "rest") + + assertStringArray(t, ret, []string{"rest"}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.Command.G { + t.Errorf("Expected Command.G to be true") + } + + assertString(t, opts.Command.Nested.N, "n") + + if c := p.Command.Find("cmd"); c == nil { + t.Errorf("Expected to find command `cmd'") + } else { + if c != p.Active { + t.Errorf("Expected `cmd' to be the active parser command") + } + + if nested := c.Find("nested"); nested == nil { + t.Errorf("Expected to find command `nested'") + } else if nested != c.Active { + t.Errorf("Expected to find command `nested' to be the active `cmd' command") + } + } +} + +func TestRequiredOnCommand(t *testing.T) { + var opts = struct { + Value bool `short:"v" required:"true"` + + Command struct { + G bool `short:"g"` + } `command:"cmd"` + }{} + + assertParseFail(t, ErrRequired, fmt.Sprintf("the required flag `%cv' was not specified", defaultShortOptDelimiter), &opts, "cmd") +} + +func TestRequiredAllOnCommand(t *testing.T) { + var opts = struct { + Value bool `short:"v" required:"true"` + Missing bool `long:"missing" required:"true"` + + Command struct { + G bool `short:"g"` + } `command:"cmd"` + }{} + + assertParseFail(t, ErrRequired, fmt.Sprintf("the required flags `%smissing' and `%cv' were not specified", defaultLongOptDelimiter, defaultShortOptDelimiter), &opts, "cmd") +} + +func TestDefaultOnCommand(t *testing.T) { + var opts = struct { + Command struct { + G bool `short:"g" default:"true"` + } `command:"cmd"` + }{} + + assertParseSuccess(t, &opts, "cmd") + + if !opts.Command.G { + t.Errorf("Expected G to be true") + } +} + +func TestSubcommandsOptional(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Cmd1 struct { + } `command:"remove"` + + Cmd2 struct { + } `command:"add"` + }{} + + p := NewParser(&opts, None) + p.SubcommandsOptional = true + + _, err := p.ParseArgs([]string{"-v"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestCommandAlias(t *testing.T) { + var opts = struct { + Command struct { + G bool `short:"g" default:"true"` + } `command:"cmd" alias:"cm"` + }{} + + assertParseSuccess(t, &opts, "cm") + + if !opts.Command.G { + t.Errorf("Expected G to be true") + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion.go new file mode 100644 index 000000000..d0adfe031 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion.go @@ -0,0 +1,304 @@ +package flags + +import ( + "fmt" + "path/filepath" + "reflect" + "sort" + "strings" + "unicode/utf8" +) + +// Completion is a type containing information of a completion. +type Completion struct { + // The completed item + Item string + + // A description of the completed item (optional) + Description string +} + +type completions []Completion + +func (c completions) Len() int { + return len(c) +} + +func (c completions) Less(i, j int) bool { + return c[i].Item < c[j].Item +} + +func (c completions) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +// Completer is an interface which can be implemented by types +// to provide custom command line argument completion. +type Completer interface { + // Complete receives a prefix representing a (partial) value + // for its type and should provide a list of possible valid + // completions. + Complete(match string) []Completion +} + +type completion struct { + parser *Parser + + ShowDescriptions bool +} + +// Filename is a string alias which provides filename completion. +type Filename string + +func completionsWithoutDescriptions(items []string) []Completion { + ret := make([]Completion, len(items)) + + for i, v := range items { + ret[i].Item = v + } + + return ret +} + +// Complete returns a list of existing files with the given +// prefix. +func (f *Filename) Complete(match string) []Completion { + ret, _ := filepath.Glob(match + "*") + return completionsWithoutDescriptions(ret) +} + +func (c *completion) skipPositional(s *parseState, n int) { + if n >= len(s.positional) { + s.positional = nil + } else { + s.positional = s.positional[n:] + } +} + +func (c *completion) completeOptionNames(names map[string]*Option, prefix string, match string) []Completion { + n := make([]Completion, 0, len(names)) + + for k, opt := range names { + if strings.HasPrefix(k, match) { + n = append(n, Completion{ + Item: prefix + k, + Description: opt.Description, + }) + } + } + + return n +} + +func (c *completion) completeLongNames(s *parseState, prefix string, match string) []Completion { + return c.completeOptionNames(s.lookup.longNames, prefix, match) +} + +func (c *completion) completeShortNames(s *parseState, prefix string, match string) []Completion { + if len(match) != 0 { + return []Completion{ + Completion{ + Item: prefix + match, + }, + } + } + + return c.completeOptionNames(s.lookup.shortNames, prefix, match) +} + +func (c *completion) completeCommands(s *parseState, match string) []Completion { + n := make([]Completion, 0, len(s.command.commands)) + + for _, cmd := range s.command.commands { + if cmd.data != c && strings.HasPrefix(cmd.Name, match) { + n = append(n, Completion{ + Item: cmd.Name, + Description: cmd.ShortDescription, + }) + } + } + + return n +} + +func (c *completion) completeValue(value reflect.Value, prefix string, match string) []Completion { + i := value.Interface() + + var ret []Completion + + if cmp, ok := i.(Completer); ok { + ret = cmp.Complete(match) + } else if value.CanAddr() { + if cmp, ok = value.Addr().Interface().(Completer); ok { + ret = cmp.Complete(match) + } + } + + for i, v := range ret { + ret[i].Item = prefix + v.Item + } + + return ret +} + +func (c *completion) completeArg(arg *Arg, prefix string, match string) []Completion { + if arg.isRemaining() { + // For remaining positional args (that are parsed into a slice), complete + // based on the element type. + return c.completeValue(reflect.New(arg.value.Type().Elem()), prefix, match) + } + + return c.completeValue(arg.value, prefix, match) +} + +func (c *completion) complete(args []string) []Completion { + if len(args) == 0 { + args = []string{""} + } + + s := &parseState{ + args: args, + } + + c.parser.fillParseState(s) + + var opt *Option + + for len(s.args) > 1 { + arg := s.pop() + + if (c.parser.Options&PassDoubleDash) != None && arg == "--" { + opt = nil + c.skipPositional(s, len(s.args)-1) + + break + } + + if argumentIsOption(arg) { + prefix, optname, islong := stripOptionPrefix(arg) + optname, _, argument := splitOption(prefix, optname, islong) + + if argument == nil { + var o *Option + canarg := true + + if islong { + o = s.lookup.longNames[optname] + } else { + for i, r := range optname { + sname := string(r) + o = s.lookup.shortNames[sname] + + if o == nil { + break + } + + if i == 0 && o.canArgument() && len(optname) != len(sname) { + canarg = false + break + } + } + } + + if o == nil && (c.parser.Options&PassAfterNonOption) != None { + opt = nil + c.skipPositional(s, len(s.args)-1) + + break + } else if o != nil && o.canArgument() && !o.OptionalArgument && canarg { + if len(s.args) > 1 { + s.pop() + } else { + opt = o + } + } + } + } else { + if len(s.positional) > 0 { + if !s.positional[0].isRemaining() { + // Don't advance beyond a remaining positional arg (because + // it consumes all subsequent args). + s.positional = s.positional[1:] + } + } else if cmd, ok := s.lookup.commands[arg]; ok { + cmd.fillParseState(s) + } + + opt = nil + } + } + + lastarg := s.args[len(s.args)-1] + var ret []Completion + + if opt != nil { + // Completion for the argument of 'opt' + ret = c.completeValue(opt.value, "", lastarg) + } else if argumentStartsOption(lastarg) { + // Complete the option + prefix, optname, islong := stripOptionPrefix(lastarg) + optname, split, argument := splitOption(prefix, optname, islong) + + if argument == nil && !islong { + rname, n := utf8.DecodeRuneInString(optname) + sname := string(rname) + + if opt := s.lookup.shortNames[sname]; opt != nil && opt.canArgument() { + ret = c.completeValue(opt.value, prefix+sname, optname[n:]) + } else { + ret = c.completeShortNames(s, prefix, optname) + } + } else if argument != nil { + if islong { + opt = s.lookup.longNames[optname] + } else { + opt = s.lookup.shortNames[optname] + } + + if opt != nil { + ret = c.completeValue(opt.value, prefix+optname+split, *argument) + } + } else if islong { + ret = c.completeLongNames(s, prefix, optname) + } else { + ret = c.completeShortNames(s, prefix, optname) + } + } else if len(s.positional) > 0 { + // Complete for positional argument + ret = c.completeArg(s.positional[0], "", lastarg) + } else if len(s.command.commands) > 0 { + // Complete for command + ret = c.completeCommands(s, lastarg) + } + + sort.Sort(completions(ret)) + return ret +} + +func (c *completion) execute(args []string) { + ret := c.complete(args) + + if c.ShowDescriptions && len(ret) > 1 { + maxl := 0 + + for _, v := range ret { + if len(v.Item) > maxl { + maxl = len(v.Item) + } + } + + for _, v := range ret { + fmt.Printf("%s", v.Item) + + if len(v.Description) > 0 { + fmt.Printf("%s # %s", strings.Repeat(" ", maxl-len(v.Item)), v.Description) + } + + fmt.Printf("\n") + } + } else { + for _, v := range ret { + fmt.Println(v.Item) + } + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion_test.go new file mode 100644 index 000000000..2d5a97f59 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/completion_test.go @@ -0,0 +1,289 @@ +package flags + +import ( + "bytes" + "io" + "os" + "path" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" +) + +type TestComplete struct { +} + +func (t *TestComplete) Complete(match string) []Completion { + options := []string{ + "hello world", + "hello universe", + "hello multiverse", + } + + ret := make([]Completion, 0, len(options)) + + for _, o := range options { + if strings.HasPrefix(o, match) { + ret = append(ret, Completion{ + Item: o, + }) + } + } + + return ret +} + +var completionTestOptions struct { + Verbose bool `short:"v" long:"verbose" description:"Verbose messages"` + Debug bool `short:"d" long:"debug" description:"Enable debug"` + Version bool `long:"version" description:"Show version"` + Required bool `long:"required" required:"true" description:"This is required"` + + AddCommand struct { + Positional struct { + Filename Filename + } `positional-args:"yes"` + } `command:"add" description:"add an item"` + + AddMultiCommand struct { + Positional struct { + Filename []Filename + } `positional-args:"yes"` + } `command:"add-multi" description:"add multiple items"` + + RemoveCommand struct { + Other bool `short:"o"` + File Filename `short:"f" long:"filename"` + } `command:"rm" description:"remove an item"` + + RenameCommand struct { + Completed TestComplete `short:"c" long:"completed"` + } `command:"rename" description:"rename an item"` +} + +type completionTest struct { + Args []string + Completed []string + ShowDescriptions bool +} + +var completionTests []completionTest + +func init() { + _, sourcefile, _, _ := runtime.Caller(0) + completionTestSourcedir := filepath.Join(filepath.SplitList(path.Dir(sourcefile))...) + + completionTestFilename := []string{filepath.Join(completionTestSourcedir, "completion.go"), filepath.Join(completionTestSourcedir, "completion_test.go")} + + completionTests = []completionTest{ + { + // Short names + []string{"-"}, + []string{"-d", "-v"}, + false, + }, + + { + // Short names concatenated + []string{"-dv"}, + []string{"-dv"}, + false, + }, + + { + // Long names + []string{"--"}, + []string{"--debug", "--required", "--verbose", "--version"}, + false, + }, + + { + // Long names with descriptions + []string{"--"}, + []string{ + "--debug # Enable debug", + "--required # This is required", + "--verbose # Verbose messages", + "--version # Show version", + }, + true, + }, + + { + // Long names partial + []string{"--ver"}, + []string{"--verbose", "--version"}, + false, + }, + + { + // Commands + []string{""}, + []string{"add", "add-multi", "rename", "rm"}, + false, + }, + + { + // Commands with descriptions + []string{""}, + []string{ + "add # add an item", + "add-multi # add multiple items", + "rename # rename an item", + "rm # remove an item", + }, + true, + }, + + { + // Commands partial + []string{"r"}, + []string{"rename", "rm"}, + false, + }, + + { + // Positional filename + []string{"add", filepath.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + + { + // Multiple positional filename (1 arg) + []string{"add-multi", filepath.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + { + // Multiple positional filename (2 args) + []string{"add-multi", filepath.Join(completionTestSourcedir, "completion.go"), filepath.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + { + // Multiple positional filename (3 args) + []string{"add-multi", filepath.Join(completionTestSourcedir, "completion.go"), filepath.Join(completionTestSourcedir, "completion.go"), filepath.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + + { + // Flag filename + []string{"rm", "-f", path.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + + { + // Flag short concat last filename + []string{"rm", "-of", path.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + + { + // Flag concat filename + []string{"rm", "-f" + path.Join(completionTestSourcedir, "completion")}, + []string{"-f" + completionTestFilename[0], "-f" + completionTestFilename[1]}, + false, + }, + + { + // Flag equal concat filename + []string{"rm", "-f=" + path.Join(completionTestSourcedir, "completion")}, + []string{"-f=" + completionTestFilename[0], "-f=" + completionTestFilename[1]}, + false, + }, + + { + // Flag concat long filename + []string{"rm", "--filename=" + path.Join(completionTestSourcedir, "completion")}, + []string{"--filename=" + completionTestFilename[0], "--filename=" + completionTestFilename[1]}, + false, + }, + + { + // Flag long filename + []string{"rm", "--filename", path.Join(completionTestSourcedir, "completion")}, + completionTestFilename, + false, + }, + + { + // Custom completed + []string{"rename", "-c", "hello un"}, + []string{"hello universe"}, + false, + }, + } +} + +func TestCompletion(t *testing.T) { + p := NewParser(&completionTestOptions, Default) + c := &completion{parser: p} + + for _, test := range completionTests { + if test.ShowDescriptions { + continue + } + + ret := c.complete(test.Args) + items := make([]string, len(ret)) + + for i, v := range ret { + items[i] = v.Item + } + + if !reflect.DeepEqual(items, test.Completed) { + t.Errorf("Args: %#v, %#v\n Expected: %#v\n Got: %#v", test.Args, test.ShowDescriptions, test.Completed, items) + } + } +} + +func TestParserCompletion(t *testing.T) { + for _, test := range completionTests { + if test.ShowDescriptions { + os.Setenv("GO_FLAGS_COMPLETION", "verbose") + } else { + os.Setenv("GO_FLAGS_COMPLETION", "1") + } + + tmp := os.Stdout + + r, w, _ := os.Pipe() + os.Stdout = w + + out := make(chan string) + + go func() { + var buf bytes.Buffer + + io.Copy(&buf, r) + + out <- buf.String() + }() + + p := NewParser(&completionTestOptions, None) + + _, err := p.ParseArgs(test.Args) + + w.Close() + + os.Stdout = tmp + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + got := strings.Split(strings.Trim(<-out, "\n"), "\n") + + if !reflect.DeepEqual(got, test.Completed) { + t.Errorf("Expected: %#v\nGot: %#v", test.Completed, got) + } + } + + os.Setenv("GO_FLAGS_COMPLETION", "") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert.go new file mode 100644 index 000000000..191b5f4cd --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert.go @@ -0,0 +1,357 @@ +// Copyright 2012 Jesse van den Kieboom. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flags + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +// Marshaler is the interface implemented by types that can marshal themselves +// to a string representation of the flag. +type Marshaler interface { + // MarshalFlag marshals a flag value to its string representation. + MarshalFlag() (string, error) +} + +// Unmarshaler is the interface implemented by types that can unmarshal a flag +// argument to themselves. The provided value is directly passed from the +// command line. +type Unmarshaler interface { + // UnmarshalFlag unmarshals a string value representation to the flag + // value (which therefore needs to be a pointer receiver). + UnmarshalFlag(value string) error +} + +func getBase(options multiTag, base int) (int, error) { + sbase := options.Get("base") + + var err error + var ivbase int64 + + if sbase != "" { + ivbase, err = strconv.ParseInt(sbase, 10, 32) + base = int(ivbase) + } + + return base, err +} + +func convertMarshal(val reflect.Value) (bool, string, error) { + // Check first for the Marshaler interface + if val.Type().NumMethod() > 0 && val.CanInterface() { + if marshaler, ok := val.Interface().(Marshaler); ok { + ret, err := marshaler.MarshalFlag() + return true, ret, err + } + } + + return false, "", nil +} + +func convertToString(val reflect.Value, options multiTag) (string, error) { + if ok, ret, err := convertMarshal(val); ok { + return ret, err + } + + tp := val.Type() + + // Support for time.Duration + if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() { + stringer := val.Interface().(fmt.Stringer) + return stringer.String(), nil + } + + switch tp.Kind() { + case reflect.String: + return val.String(), nil + case reflect.Bool: + if val.Bool() { + return "true", nil + } + + return "false", nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + base, err := getBase(options, 10) + + if err != nil { + return "", err + } + + return strconv.FormatInt(val.Int(), base), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + base, err := getBase(options, 10) + + if err != nil { + return "", err + } + + return strconv.FormatUint(val.Uint(), base), nil + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil + case reflect.Slice: + if val.Len() == 0 { + return "", nil + } + + ret := "[" + + for i := 0; i < val.Len(); i++ { + if i != 0 { + ret += ", " + } + + item, err := convertToString(val.Index(i), options) + + if err != nil { + return "", err + } + + ret += item + } + + return ret + "]", nil + case reflect.Map: + ret := "{" + + for i, key := range val.MapKeys() { + if i != 0 { + ret += ", " + } + + keyitem, err := convertToString(key, options) + + if err != nil { + return "", err + } + + item, err := convertToString(val.MapIndex(key), options) + + if err != nil { + return "", err + } + + ret += keyitem + ":" + item + } + + return ret + "}", nil + case reflect.Ptr: + return convertToString(reflect.Indirect(val), options) + case reflect.Interface: + if !val.IsNil() { + return convertToString(val.Elem(), options) + } + } + + return "", nil +} + +func convertUnmarshal(val string, retval reflect.Value) (bool, error) { + if retval.Type().NumMethod() > 0 && retval.CanInterface() { + if unmarshaler, ok := retval.Interface().(Unmarshaler); ok { + return true, unmarshaler.UnmarshalFlag(val) + } + } + + if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() { + return convertUnmarshal(val, retval.Addr()) + } + + if retval.Type().Kind() == reflect.Interface && !retval.IsNil() { + return convertUnmarshal(val, retval.Elem()) + } + + return false, nil +} + +func convert(val string, retval reflect.Value, options multiTag) error { + if ok, err := convertUnmarshal(val, retval); ok { + return err + } + + tp := retval.Type() + + // Support for time.Duration + if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() { + parsed, err := time.ParseDuration(val) + + if err != nil { + return err + } + + retval.SetInt(int64(parsed)) + return nil + } + + switch tp.Kind() { + case reflect.String: + retval.SetString(val) + case reflect.Bool: + if val == "" { + retval.SetBool(true) + } else { + b, err := strconv.ParseBool(val) + + if err != nil { + return err + } + + retval.SetBool(b) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + base, err := getBase(options, 10) + + if err != nil { + return err + } + + parsed, err := strconv.ParseInt(val, base, tp.Bits()) + + if err != nil { + return err + } + + retval.SetInt(parsed) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + base, err := getBase(options, 10) + + if err != nil { + return err + } + + parsed, err := strconv.ParseUint(val, base, tp.Bits()) + + if err != nil { + return err + } + + retval.SetUint(parsed) + case reflect.Float32, reflect.Float64: + parsed, err := strconv.ParseFloat(val, tp.Bits()) + + if err != nil { + return err + } + + retval.SetFloat(parsed) + case reflect.Slice: + elemtp := tp.Elem() + + elemvalptr := reflect.New(elemtp) + elemval := reflect.Indirect(elemvalptr) + + if err := convert(val, elemval, options); err != nil { + return err + } + + retval.Set(reflect.Append(retval, elemval)) + case reflect.Map: + parts := strings.SplitN(val, ":", 2) + + key := parts[0] + var value string + + if len(parts) == 2 { + value = parts[1] + } + + keytp := tp.Key() + keyval := reflect.New(keytp) + + if err := convert(key, keyval, options); err != nil { + return err + } + + valuetp := tp.Elem() + valueval := reflect.New(valuetp) + + if err := convert(value, valueval, options); err != nil { + return err + } + + if retval.IsNil() { + retval.Set(reflect.MakeMap(tp)) + } + + retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval)) + case reflect.Ptr: + if retval.IsNil() { + retval.Set(reflect.New(retval.Type().Elem())) + } + + return convert(val, reflect.Indirect(retval), options) + case reflect.Interface: + if !retval.IsNil() { + return convert(val, retval.Elem(), options) + } + } + + return nil +} + +func isPrint(s string) bool { + for _, c := range s { + if !strconv.IsPrint(c) { + return false + } + } + + return true +} + +func quoteIfNeeded(s string) string { + if !isPrint(s) { + return strconv.Quote(s) + } + + return s +} + +func unquoteIfPossible(s string) (string, error) { + if len(s) == 0 || s[0] != '"' { + return s, nil + } + + return strconv.Unquote(s) +} + +func wrapText(s string, l int, prefix string) string { + // Basic text wrapping of s at spaces to fit in l + var ret string + + s = strings.TrimSpace(s) + + for len(s) > l { + // Try to split on space + suffix := "" + + pos := strings.LastIndex(s[:l], " ") + + if pos < 0 { + pos = l - 1 + suffix = "-\n" + } + + if len(ret) != 0 { + ret += "\n" + prefix + } + + ret += strings.TrimSpace(s[:pos]) + suffix + s = strings.TrimSpace(s[pos:]) + } + + if len(s) > 0 { + if len(ret) != 0 { + ret += "\n" + prefix + } + + return ret + s + } + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert_test.go new file mode 100644 index 000000000..0de0eea7a --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/convert_test.go @@ -0,0 +1,175 @@ +package flags + +import ( + "testing" + "time" +) + +func expectConvert(t *testing.T, o *Option, expected string) { + s, err := convertToString(o.value, o.tag) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + assertString(t, s, expected) +} + +func TestConvertToString(t *testing.T) { + d, _ := time.ParseDuration("1h2m4s") + + var opts = struct { + String string `long:"string"` + + Int int `long:"int"` + Int8 int8 `long:"int8"` + Int16 int16 `long:"int16"` + Int32 int32 `long:"int32"` + Int64 int64 `long:"int64"` + + Uint uint `long:"uint"` + Uint8 uint8 `long:"uint8"` + Uint16 uint16 `long:"uint16"` + Uint32 uint32 `long:"uint32"` + Uint64 uint64 `long:"uint64"` + + Float32 float32 `long:"float32"` + Float64 float64 `long:"float64"` + + Duration time.Duration `long:"duration"` + + Bool bool `long:"bool"` + + IntSlice []int `long:"int-slice"` + IntFloatMap map[int]float64 `long:"int-float-map"` + + PtrBool *bool `long:"ptr-bool"` + Interface interface{} `long:"interface"` + + Int32Base int32 `long:"int32-base" base:"16"` + Uint32Base uint32 `long:"uint32-base" base:"16"` + }{ + "string", + + -2, + -1, + 0, + 1, + 2, + + 1, + 2, + 3, + 4, + 5, + + 1.2, + -3.4, + + d, + true, + + []int{-3, 4, -2}, + map[int]float64{-2: 4.5}, + + new(bool), + float32(5.2), + + -5823, + 4232, + } + + p := NewNamedParser("test", Default) + grp, _ := p.AddGroup("test group", "", &opts) + + expects := []string{ + "string", + "-2", + "-1", + "0", + "1", + "2", + + "1", + "2", + "3", + "4", + "5", + + "1.2", + "-3.4", + + "1h2m4s", + "true", + + "[-3, 4, -2]", + "{-2:4.5}", + + "false", + "5.2", + + "-16bf", + "1088", + } + + for i, v := range grp.Options() { + expectConvert(t, v, expects[i]) + } +} + +func TestConvertToStringInvalidIntBase(t *testing.T) { + var opts = struct { + Int int `long:"int" base:"no"` + }{ + 2, + } + + p := NewNamedParser("test", Default) + grp, _ := p.AddGroup("test group", "", &opts) + o := grp.Options()[0] + + _, err := convertToString(o.value, o.tag) + + if err != nil { + err = newErrorf(ErrMarshal, "%v", err) + } + + assertError(t, err, ErrMarshal, "strconv.ParseInt: parsing \"no\": invalid syntax") +} + +func TestConvertToStringInvalidUintBase(t *testing.T) { + var opts = struct { + Uint uint `long:"uint" base:"no"` + }{ + 2, + } + + p := NewNamedParser("test", Default) + grp, _ := p.AddGroup("test group", "", &opts) + o := grp.Options()[0] + + _, err := convertToString(o.value, o.tag) + + if err != nil { + err = newErrorf(ErrMarshal, "%v", err) + } + + assertError(t, err, ErrMarshal, "strconv.ParseInt: parsing \"no\": invalid syntax") +} + +func TestWrapText(t *testing.T) { + s := "Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum." + + got := wrapText(s, 60, " ") + expected := `Lorem ipsum dolor sit amet, consectetur adipisicing elit, + sed do eiusmod tempor incididunt ut labore et dolore magna + aliqua. Ut enim ad minim veniam, quis nostrud exercitation + ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit + esse cillum dolore eu fugiat nulla pariatur. Excepteur sint + occaecat cupidatat non proident, sunt in culpa qui officia + deserunt mollit anim id est laborum.` + + assertDiff(t, got, expected, "wrapped text") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/error.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/error.go new file mode 100644 index 000000000..fce9d3121 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/error.go @@ -0,0 +1,123 @@ +package flags + +import ( + "fmt" +) + +// ErrorType represents the type of error. +type ErrorType uint + +const ( + // ErrUnknown indicates a generic error. + ErrUnknown ErrorType = iota + + // ErrExpectedArgument indicates that an argument was expected. + ErrExpectedArgument + + // ErrUnknownFlag indicates an unknown flag. + ErrUnknownFlag + + // ErrUnknownGroup indicates an unknown group. + ErrUnknownGroup + + // ErrMarshal indicates a marshalling error while converting values. + ErrMarshal + + // ErrHelp indicates that the built-in help was shown (the error + // contains the help message). + ErrHelp + + // ErrNoArgumentForBool indicates that an argument was given for a + // boolean flag (which don't not take any arguments). + ErrNoArgumentForBool + + // ErrRequired indicates that a required flag was not provided. + ErrRequired + + // ErrShortNameTooLong indicates that a short flag name was specified, + // longer than one character. + ErrShortNameTooLong + + // ErrDuplicatedFlag indicates that a short or long flag has been + // defined more than once + ErrDuplicatedFlag + + // ErrTag indicates an error while parsing flag tags. + ErrTag + + // ErrCommandRequired indicates that a command was required but not + // specified + ErrCommandRequired + + // ErrUnknownCommand indicates that an unknown command was specified. + ErrUnknownCommand +) + +func (e ErrorType) String() string { + switch e { + case ErrUnknown: + return "unknown" + case ErrExpectedArgument: + return "expected argument" + case ErrUnknownFlag: + return "unknown flag" + case ErrUnknownGroup: + return "unknown group" + case ErrMarshal: + return "marshal" + case ErrHelp: + return "help" + case ErrNoArgumentForBool: + return "no argument for bool" + case ErrRequired: + return "required" + case ErrShortNameTooLong: + return "short name too long" + case ErrDuplicatedFlag: + return "duplicated flag" + case ErrTag: + return "tag" + case ErrCommandRequired: + return "command required" + case ErrUnknownCommand: + return "unknown command" + } + + return "unrecognized error type" +} + +// Error represents a parser error. The error returned from Parse is of this +// type. The error contains both a Type and Message. +type Error struct { + // The type of error + Type ErrorType + + // The error message + Message string +} + +// Error returns the error's message +func (e *Error) Error() string { + return e.Message +} + +func newError(tp ErrorType, message string) *Error { + return &Error{ + Type: tp, + Message: message, + } +} + +func newErrorf(tp ErrorType, format string, args ...interface{}) *Error { + return newError(tp, fmt.Sprintf(format, args...)) +} + +func wrapError(err error) *Error { + ret, ok := err.(*Error) + + if !ok { + return newError(ErrUnknown, err.Error()) + } + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/example_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/example_test.go new file mode 100644 index 000000000..f7be2bb14 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/example_test.go @@ -0,0 +1,110 @@ +// Example of use of the flags package. +package flags + +import ( + "fmt" + "os/exec" +) + +func Example() { + var opts struct { + // Slice of bool will append 'true' each time the option + // is encountered (can be set multiple times, like -vvv) + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"` + + // Example of automatic marshalling to desired type (uint) + Offset uint `long:"offset" description:"Offset"` + + // Example of a callback, called each time the option is found. + Call func(string) `short:"c" description:"Call phone number"` + + // Example of a required flag + Name string `short:"n" long:"name" description:"A name" required:"true"` + + // Example of a value name + File string `short:"f" long:"file" description:"A file" value-name:"FILE"` + + // Example of a pointer + Ptr *int `short:"p" description:"A pointer to an integer"` + + // Example of a slice of strings + StringSlice []string `short:"s" description:"A slice of strings"` + + // Example of a slice of pointers + PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"` + + // Example of a map + IntMap map[string]int `long:"intmap" description:"A map from string to int"` + + // Example of a filename (useful for completion) + Filename Filename `long:"filename" description:"A filename"` + + // Example of positional arguments + Args struct { + Id string + Num int + Rest []string + } `positional-args:"yes" required:"yes"` + } + + // Callback which will invoke callto: to call a number. + // Note that this works just on OS X (and probably only with + // Skype) but it shows the idea. + opts.Call = func(num string) { + cmd := exec.Command("open", "callto:"+num) + cmd.Start() + cmd.Process.Release() + } + + // Make some fake arguments to parse. + args := []string{ + "-vv", + "--offset=5", + "-n", "Me", + "-p", "3", + "-s", "hello", + "-s", "world", + "--ptrslice", "hello", + "--ptrslice", "world", + "--intmap", "a:1", + "--intmap", "b:5", + "--filename", "hello.go", + "id", + "10", + "remaining1", + "remaining2", + } + + // Parse flags from `args'. Note that here we use flags.ParseArgs for + // the sake of making a working example. Normally, you would simply use + // flags.Parse(&opts) which uses os.Args + _, err := ParseArgs(&opts, args) + + if err != nil { + panic(err) + } + + fmt.Printf("Verbosity: %v\n", opts.Verbose) + fmt.Printf("Offset: %d\n", opts.Offset) + fmt.Printf("Name: %s\n", opts.Name) + fmt.Printf("Ptr: %d\n", *opts.Ptr) + fmt.Printf("StringSlice: %v\n", opts.StringSlice) + fmt.Printf("PtrSlice: [%v %v]\n", *opts.PtrSlice[0], *opts.PtrSlice[1]) + fmt.Printf("IntMap: [a:%v b:%v]\n", opts.IntMap["a"], opts.IntMap["b"]) + fmt.Printf("Filename: %v\n", opts.Filename) + fmt.Printf("Args.Id: %s\n", opts.Args.Id) + fmt.Printf("Args.Num: %d\n", opts.Args.Num) + fmt.Printf("Args.Rest: %v\n", opts.Args.Rest) + + // Output: Verbosity: [true true] + // Offset: 5 + // Name: Me + // Ptr: 3 + // StringSlice: [hello world] + // PtrSlice: [hello world] + // IntMap: [a:1 b:5] + // Filename: hello.go + // Args.Id: id + // Args.Num: 10 + // Args.Rest: [remaining1 remaining2] +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/add.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/add.go new file mode 100644 index 000000000..57d8f232b --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/add.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" +) + +type AddCommand struct { + All bool `short:"a" long:"all" description:"Add all files"` +} + +var addCommand AddCommand + +func (x *AddCommand) Execute(args []string) error { + fmt.Printf("Adding (all=%v): %#v\n", x.All, args) + return nil +} + +func init() { + parser.AddCommand("add", + "Add a file", + "The add command adds a file to the repository. Use -a to add all files.", + &addCommand) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/bash-completion b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/bash-completion new file mode 100644 index 000000000..974f52ad4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/bash-completion @@ -0,0 +1,9 @@ +_examples() { + args=("${COMP_WORDS[@]:1:$COMP_CWORD}") + + local IFS=$'\n' + COMPREPLY=($(GO_FLAGS_COMPLETION=1 ${COMP_WORDS[0]} "${args[@]}")) + return 1 +} + +complete -F _examples examples diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/main.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/main.go new file mode 100644 index 000000000..4a22be6e8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "errors" + "fmt" + "github.com/jessevdk/go-flags" + "os" + "strconv" + "strings" +) + +type EditorOptions struct { + Input flags.Filename `short:"i" long:"input" description:"Input file" default:"-"` + Output flags.Filename `short:"o" long:"output" description:"Output file" default:"-"` +} + +type Point struct { + X, Y int +} + +func (p *Point) UnmarshalFlag(value string) error { + parts := strings.Split(value, ",") + + if len(parts) != 2 { + return errors.New("expected two numbers separated by a ,") + } + + x, err := strconv.ParseInt(parts[0], 10, 32) + + if err != nil { + return err + } + + y, err := strconv.ParseInt(parts[1], 10, 32) + + if err != nil { + return err + } + + p.X = int(x) + p.Y = int(y) + + return nil +} + +func (p Point) MarshalFlag() (string, error) { + return fmt.Sprintf("%d,%d", p.X, p.Y), nil +} + +type Options struct { + // Example of verbosity with level + Verbose []bool `short:"v" long:"verbose" description:"Verbose output"` + + // Example of optional value + User string `short:"u" long:"user" description:"User name" optional:"yes" optional-value:"pancake"` + + // Example of map with multiple default values + Users map[string]string `long:"users" description:"User e-mail map" default:"system:system@example.org" default:"admin:admin@example.org"` + + // Example of option group + Editor EditorOptions `group:"Editor Options"` + + // Example of custom type Marshal/Unmarshal + Point Point `long:"point" description:"A x,y point" default:"1,2"` +} + +var options Options + +var parser = flags.NewParser(&options, flags.Default) + +func main() { + if _, err := parser.Parse(); err != nil { + os.Exit(1) + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/rm.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/rm.go new file mode 100644 index 000000000..c9c1dd03a --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/examples/rm.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" +) + +type RmCommand struct { + Force bool `short:"f" long:"force" description:"Force removal of files"` +} + +var rmCommand RmCommand + +func (x *RmCommand) Execute(args []string) error { + fmt.Printf("Removing (force=%v): %#v\n", x.Force, args) + return nil +} + +func init() { + parser.AddCommand("rm", + "Remove a file", + "The rm command removes a file to the repository. Use -f to force removal of files.", + &rmCommand) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/flags.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/flags.go new file mode 100644 index 000000000..37d331d0a --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/flags.go @@ -0,0 +1,242 @@ +// Copyright 2012 Jesse van den Kieboom. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package flags provides an extensive command line option parser. +The flags package is similar in functionality to the go built-in flag package +but provides more options and uses reflection to provide a convenient and +succinct way of specifying command line options. + + +Supported features + +The following features are supported in go-flags: + + Options with short names (-v) + Options with long names (--verbose) + Options with and without arguments (bool v.s. other type) + Options with optional arguments and default values + Option default values from ENVIRONMENT_VARIABLES, including slice and map values + Multiple option groups each containing a set of options + Generate and print well-formatted help message + Passing remaining command line arguments after -- (optional) + Ignoring unknown command line options (optional) + Supports -I/usr/include -I=/usr/include -I /usr/include option argument specification + Supports multiple short options -aux + Supports all primitive go types (string, int{8..64}, uint{8..64}, float) + Supports same option multiple times (can store in slice or last option counts) + Supports maps + Supports function callbacks + Supports namespaces for (nested) option groups + +Additional features specific to Windows: + Options with short names (/v) + Options with long names (/verbose) + Windows-style options with arguments use a colon as the delimiter + Modify generated help message with Windows-style / options + + +Basic usage + +The flags package uses structs, reflection and struct field tags +to allow users to specify command line options. This results in very simple +and concise specification of your application options. For example: + + type Options struct { + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"` + } + +This specifies one option with a short name -v and a long name --verbose. +When either -v or --verbose is found on the command line, a 'true' value +will be appended to the Verbose field. e.g. when specifying -vvv, the +resulting value of Verbose will be {[true, true, true]}. + +Slice options work exactly the same as primitive type options, except that +whenever the option is encountered, a value is appended to the slice. + +Map options from string to primitive type are also supported. On the command +line, you specify the value for such an option as key:value. For example + + type Options struct { + AuthorInfo string[string] `short:"a"` + } + +Then, the AuthorInfo map can be filled with something like +-a name:Jesse -a "surname:van den Kieboom". + +Finally, for full control over the conversion between command line argument +values and options, user defined types can choose to implement the Marshaler +and Unmarshaler interfaces. + + +Available field tags + +The following is a list of tags for struct fields supported by go-flags: + + short: the short name of the option (single character) + long: the long name of the option + required: whether an option is required to appear on the command + line. If a required option is not present, the parser will + return ErrRequired (optional) + description: the description of the option (optional) + long-description: the long description of the option. Currently only + displayed in generated man pages (optional) + no-flag: if non-empty this field is ignored as an option (optional) + + optional: whether an argument of the option is optional (optional) + optional-value: the value of an optional option when the option occurs + without an argument. This tag can be specified multiple + times in the case of maps or slices (optional) + default: the default value of an option. This tag can be specified + multiple times in the case of slices or maps (optional) + default-mask: when specified, this value will be displayed in the help + instead of the actual default value. This is useful + mostly for hiding otherwise sensitive information from + showing up in the help. If default-mask takes the special + value "-", then no default value will be shown at all + (optional) + env: the default value of the option is overridden from the + specified environment variable, if one has been defined. + (optional) + env-delim: the 'env' default value from environment is split into + multiple values with the given delimiter string, use with + slices and maps (optional) + value-name: the name of the argument value (to be shown in the help) + (optional) + + base: a base (radix) used to convert strings to integer values, the + default base is 10 (i.e. decimal) (optional) + + ini-name: the explicit ini option name (optional) + no-ini: if non-empty this field is ignored as an ini option + (optional) + + group: when specified on a struct field, makes the struct + field a separate group with the given name (optional) + namespace: when specified on a group struct field, the namespace + gets prepended to every option's long name and + subgroup's namespace of this group, separated by + the parser's namespace delimiter (optional) + command: when specified on a struct field, makes the struct + field a (sub)command with the given name (optional) + subcommands-optional: when specified on a command struct field, makes + any subcommands of that command optional (optional) + alias: when specified on a command struct field, adds the + specified name as an alias for the command. Can be + be specified multiple times to add more than one + alias (optional) + positional-args: when specified on a field with a struct type, + uses the fields of that struct to parse remaining + positional command line arguments into (in order + of the fields). If a field has a slice type, + then all remaining arguments will be added to it. + Positional arguments are optional by default, + unless the "required" tag is specified together + with the "positional-args" tag (optional) + positional-arg-name: used on a field in a positional argument struct; name + of the positional argument placeholder to be shown in + the help (optional) + + +Either the `short:` tag or the `long:` must be specified to make the field eligible as an +option. + + +Option groups + +Option groups are a simple way to semantically separate your options. All +options in a particular group are shown together in the help under the name +of the group. Namespaces can be used to specify option long names more +precisely and emphasize the options affiliation to their group. + +There are currently three ways to specify option groups. + + 1. Use NewNamedParser specifying the various option groups. + 2. Use AddGroup to add a group to an existing parser. + 3. Add a struct field to the top-level options annotated with the + group:"group-name" tag. + + + +Commands + +The flags package also has basic support for commands. Commands are often +used in monolithic applications that support various commands or actions. +Take git for example, all of the add, commit, checkout, etc. are called +commands. Using commands you can easily separate multiple functions of your +application. + +There are currently two ways to specify a command. + + 1. Use AddCommand on an existing parser. + 2. Add a struct field to your options struct annotated with the + command:"command-name" tag. + +The most common, idiomatic way to implement commands is to define a global +parser instance and implement each command in a separate file. These +command files should define a go init function which calls AddCommand on +the global parser. + +When parsing ends and there is an active command and that command implements +the Commander interface, then its Execute method will be run with the +remaining command line arguments. + +Command structs can have options which become valid to parse after the +command has been specified on the command line. It is currently not valid +to specify options from the parent level of the command after the command +name has occurred. Thus, given a top-level option "-v" and a command "add": + + Valid: ./app -v add + Invalid: ./app add -v + + +Completion + +go-flags has builtin support to provide bash completion of flags, commands +and argument values. To use completion, the binary which uses go-flags +can be invoked in a special environment to list completion of the current +command line argument. It should be noted that this `executes` your application, +and it is up to the user to make sure there are no negative side effects (for +example from init functions). + +Setting the environment variable `GO_FLAGS_COMPLETION=1` enables completion +by replacing the argument parsing routine with the completion routine which +outputs completions for the passed arguments. The basic invocation to +complete a set of arguments is therefore: + + GO_FLAGS_COMPLETION=1 ./completion-example arg1 arg2 arg3 + +where `completion-example` is the binary, `arg1` and `arg2` are +the current arguments, and `arg3` (the last argument) is the argument +to be completed. If the GO_FLAGS_COMPLETION is set to "verbose", then +descriptions of possible completion items will also be shown, if there +are more than 1 completion items. + +To use this with bash completion, a simple file can be written which +calls the binary which supports go-flags completion: + + _completion_example() { + # All arguments except the first one + args=("${COMP_WORDS[@]:1:$COMP_CWORD}") + + # Only split on newlines + local IFS=$'\n' + + # Call completion (note that the first element of COMP_WORDS is + # the executable itself) + COMPREPLY=($(GO_FLAGS_COMPLETION=1 ${COMP_WORDS[0]} "${args[@]}")) + return 0 + } + + complete -F _completion_example completion-example + +Completion requires the parser option PassDoubleDash and is therefore enforced if the environment variable GO_FLAGS_COMPLETION is set. + +Customized completion for argument values is supported by implementing +the flags.Completer interface for the argument value type. An example +of a type which does so is the flags.Filename type, an alias of string +allowing simple filename completion. A slice or array argument value +whose element type implements flags.Completer will also be completed. +*/ +package flags diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/group.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group.go new file mode 100644 index 000000000..8b609a3af --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group.go @@ -0,0 +1,91 @@ +// Copyright 2012 Jesse van den Kieboom. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flags + +import ( + "errors" + "strings" +) + +// ErrNotPointerToStruct indicates that a provided data container is not +// a pointer to a struct. Only pointers to structs are valid data containers +// for options. +var ErrNotPointerToStruct = errors.New("provided data is not a pointer to struct") + +// Group represents an option group. Option groups can be used to logically +// group options together under a description. Groups are only used to provide +// more structure to options both for the user (as displayed in the help message) +// and for you, since groups can be nested. +type Group struct { + // A short description of the group. The + // short description is primarily used in the built-in generated help + // message + ShortDescription string + + // A long description of the group. The long + // description is primarily used to present information on commands + // (Command embeds Group) in the built-in generated help and man pages. + LongDescription string + + // The namespace of the group + Namespace string + + // The parent of the group or nil if it has no parent + parent interface{} + + // All the options in the group + options []*Option + + // All the subgroups + groups []*Group + + // Whether the group represents the built-in help group + isBuiltinHelp bool + + data interface{} +} + +// AddGroup adds a new group to the command with the given name and data. The +// data needs to be a pointer to a struct from which the fields indicate which +// options are in the group. +func (g *Group) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) { + group := newGroup(shortDescription, longDescription, data) + + group.parent = g + + if err := group.scan(); err != nil { + return nil, err + } + + g.groups = append(g.groups, group) + return group, nil +} + +// Groups returns the list of groups embedded in this group. +func (g *Group) Groups() []*Group { + return g.groups +} + +// Options returns the list of options in this group. +func (g *Group) Options() []*Option { + return g.options +} + +// Find locates the subgroup with the given short description and returns it. +// If no such group can be found Find will return nil. Note that the description +// is matched case insensitively. +func (g *Group) Find(shortDescription string) *Group { + lshortDescription := strings.ToLower(shortDescription) + + var ret *Group + + g.eachGroup(func(gg *Group) { + if gg != g && strings.ToLower(gg.ShortDescription) == lshortDescription { + ret = gg + } + }) + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_private.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_private.go new file mode 100644 index 000000000..15251ce39 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_private.go @@ -0,0 +1,254 @@ +package flags + +import ( + "reflect" + "unicode/utf8" + "unsafe" +) + +type scanHandler func(reflect.Value, *reflect.StructField) (bool, error) + +func newGroup(shortDescription string, longDescription string, data interface{}) *Group { + return &Group{ + ShortDescription: shortDescription, + LongDescription: longDescription, + + data: data, + } +} + +func (g *Group) optionByName(name string, namematch func(*Option, string) bool) *Option { + prio := 0 + var retopt *Option + + for _, opt := range g.options { + if namematch != nil && namematch(opt, name) && prio < 4 { + retopt = opt + prio = 4 + } + + if name == opt.field.Name && prio < 3 { + retopt = opt + prio = 3 + } + + if name == opt.LongNameWithNamespace() && prio < 2 { + retopt = opt + prio = 2 + } + + if opt.ShortName != 0 && name == string(opt.ShortName) && prio < 1 { + retopt = opt + prio = 1 + } + } + + return retopt +} + +func (g *Group) eachGroup(f func(*Group)) { + f(g) + + for _, gg := range g.groups { + gg.eachGroup(f) + } +} + +func (g *Group) scanStruct(realval reflect.Value, sfield *reflect.StructField, handler scanHandler) error { + stype := realval.Type() + + if sfield != nil { + if ok, err := handler(realval, sfield); err != nil { + return err + } else if ok { + return nil + } + } + + for i := 0; i < stype.NumField(); i++ { + field := stype.Field(i) + + // PkgName is set only for non-exported fields, which we ignore + if field.PkgPath != "" { + continue + } + + mtag := newMultiTag(string(field.Tag)) + + if err := mtag.Parse(); err != nil { + return err + } + + // Skip fields with the no-flag tag + if mtag.Get("no-flag") != "" { + continue + } + + // Dive deep into structs or pointers to structs + kind := field.Type.Kind() + fld := realval.Field(i) + + if kind == reflect.Struct { + if err := g.scanStruct(fld, &field, handler); err != nil { + return err + } + } else if kind == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { + if fld.IsNil() { + fld.Set(reflect.New(fld.Type().Elem())) + } + + if err := g.scanStruct(reflect.Indirect(fld), &field, handler); err != nil { + return err + } + } + + longname := mtag.Get("long") + shortname := mtag.Get("short") + + // Need at least either a short or long name + if longname == "" && shortname == "" && mtag.Get("ini-name") == "" { + continue + } + + short := rune(0) + rc := utf8.RuneCountInString(shortname) + + if rc > 1 { + return newErrorf(ErrShortNameTooLong, + "short names can only be 1 character long, not `%s'", + shortname) + + } else if rc == 1 { + short, _ = utf8.DecodeRuneInString(shortname) + } + + description := mtag.Get("description") + def := mtag.GetMany("default") + + optionalValue := mtag.GetMany("optional-value") + valueName := mtag.Get("value-name") + defaultMask := mtag.Get("default-mask") + + optional := (mtag.Get("optional") != "") + required := (mtag.Get("required") != "") + + option := &Option{ + Description: description, + ShortName: short, + LongName: longname, + Default: def, + EnvDefaultKey: mtag.Get("env"), + EnvDefaultDelim: mtag.Get("env-delim"), + OptionalArgument: optional, + OptionalValue: optionalValue, + Required: required, + ValueName: valueName, + DefaultMask: defaultMask, + + group: g, + + field: field, + value: realval.Field(i), + tag: mtag, + } + + g.options = append(g.options, option) + } + + return nil +} + +func (g *Group) checkForDuplicateFlags() *Error { + shortNames := make(map[rune]*Option) + longNames := make(map[string]*Option) + + var duplicateError *Error + + g.eachGroup(func(g *Group) { + for _, option := range g.options { + if option.LongName != "" { + longName := option.LongNameWithNamespace() + + if otherOption, ok := longNames[longName]; ok { + duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same long name as option `%s'", option, otherOption) + return + } + longNames[longName] = option + } + if option.ShortName != 0 { + if otherOption, ok := shortNames[option.ShortName]; ok { + duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same short name as option `%s'", option, otherOption) + return + } + shortNames[option.ShortName] = option + } + } + }) + + return duplicateError +} + +func (g *Group) scanSubGroupHandler(realval reflect.Value, sfield *reflect.StructField) (bool, error) { + mtag := newMultiTag(string(sfield.Tag)) + + if err := mtag.Parse(); err != nil { + return true, err + } + + subgroup := mtag.Get("group") + + if len(subgroup) != 0 { + ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr())) + description := mtag.Get("description") + + group, err := g.AddGroup(subgroup, description, ptrval.Interface()) + if err != nil { + return true, err + } + + group.Namespace = mtag.Get("namespace") + + return true, nil + } + + return false, nil +} + +func (g *Group) scanType(handler scanHandler) error { + // Get all the public fields in the data struct + ptrval := reflect.ValueOf(g.data) + + if ptrval.Type().Kind() != reflect.Ptr { + panic(ErrNotPointerToStruct) + } + + stype := ptrval.Type().Elem() + + if stype.Kind() != reflect.Struct { + panic(ErrNotPointerToStruct) + } + + realval := reflect.Indirect(ptrval) + + if err := g.scanStruct(realval, nil, handler); err != nil { + return err + } + + if err := g.checkForDuplicateFlags(); err != nil { + return err + } + + return nil +} + +func (g *Group) scan() error { + return g.scanType(g.scanSubGroupHandler) +} + +func (g *Group) groupByName(name string) *Group { + if len(name) == 0 { + return g + } + + return g.Find(name) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_test.go new file mode 100644 index 000000000..b5ed9d492 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/group_test.go @@ -0,0 +1,187 @@ +package flags + +import ( + "testing" +) + +func TestGroupInline(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Group struct { + G bool `short:"g"` + } `group:"Grouped Options"` + }{} + + p, ret := assertParserSuccess(t, &opts, "-v", "-g") + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.Group.G { + t.Errorf("Expected Group.G to be true") + } + + if p.Command.Group.Find("Grouped Options") == nil { + t.Errorf("Expected to find group `Grouped Options'") + } +} + +func TestGroupAdd(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + }{} + + var grp = struct { + G bool `short:"g"` + }{} + + p := NewParser(&opts, Default) + g, err := p.AddGroup("Grouped Options", "", &grp) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + ret, err := p.ParseArgs([]string{"-v", "-g", "rest"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + assertStringArray(t, ret, []string{"rest"}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !grp.G { + t.Errorf("Expected Group.G to be true") + } + + if p.Command.Group.Find("Grouped Options") != g { + t.Errorf("Expected to find group `Grouped Options'") + } + + if p.Groups()[1] != g { + t.Errorf("Expected group %#v, but got %#v", g, p.Groups()[0]) + } + + if g.Options()[0].ShortName != 'g' { + t.Errorf("Expected short name `g' but got %v", g.Options()[0].ShortName) + } +} + +func TestGroupNestedInline(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + + Group struct { + G bool `short:"g"` + + Nested struct { + N string `long:"n"` + } `group:"Nested Options"` + } `group:"Grouped Options"` + }{} + + p, ret := assertParserSuccess(t, &opts, "-v", "-g", "--n", "n", "rest") + + assertStringArray(t, ret, []string{"rest"}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + if !opts.Group.G { + t.Errorf("Expected Group.G to be true") + } + + assertString(t, opts.Group.Nested.N, "n") + + if p.Command.Group.Find("Grouped Options") == nil { + t.Errorf("Expected to find group `Grouped Options'") + } + + if p.Command.Group.Find("Nested Options") == nil { + t.Errorf("Expected to find group `Nested Options'") + } +} + +func TestGroupNestedInlineNamespace(t *testing.T) { + var opts = struct { + Opt string `long:"opt"` + + Group struct { + Opt string `long:"opt"` + Group struct { + Opt string `long:"opt"` + } `group:"Subsubgroup" namespace:"sap"` + } `group:"Subgroup" namespace:"sip"` + }{} + + p, ret := assertParserSuccess(t, &opts, "--opt", "a", "--sip.opt", "b", "--sip.sap.opt", "c", "rest") + + assertStringArray(t, ret, []string{"rest"}) + + assertString(t, opts.Opt, "a") + assertString(t, opts.Group.Opt, "b") + assertString(t, opts.Group.Group.Opt, "c") + + for _, name := range []string{"Subgroup", "Subsubgroup"} { + if p.Command.Group.Find(name) == nil { + t.Errorf("Expected to find group '%s'", name) + } + } +} + +func TestDuplicateShortFlags(t *testing.T) { + var opts struct { + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information"` + Variables []string `short:"v" long:"variable" description:"Set a variable value."` + } + + args := []string{ + "--verbose", + "-v", "123", + "-v", "456", + } + + _, err := ParseArgs(&opts, args) + + if err == nil { + t.Errorf("Expected an error with type ErrDuplicatedFlag") + } else { + err2 := err.(*Error) + if err2.Type != ErrDuplicatedFlag { + t.Errorf("Expected an error with type ErrDuplicatedFlag") + } + } +} + +func TestDuplicateLongFlags(t *testing.T) { + var opts struct { + Test1 []bool `short:"a" long:"testing" description:"Test 1"` + Test2 []string `short:"b" long:"testing" description:"Test 2."` + } + + args := []string{ + "--testing", + } + + _, err := ParseArgs(&opts, args) + + if err == nil { + t.Errorf("Expected an error with type ErrDuplicatedFlag") + } else { + err2 := err.(*Error) + if err2.Type != ErrDuplicatedFlag { + t.Errorf("Expected an error with type ErrDuplicatedFlag") + } + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/help.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/help.go new file mode 100644 index 000000000..e26fcd01c --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/help.go @@ -0,0 +1,426 @@ +// Copyright 2012 Jesse van den Kieboom. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flags + +import ( + "bufio" + "bytes" + "fmt" + "io" + "reflect" + "runtime" + "strings" + "unicode/utf8" +) + +type alignmentInfo struct { + maxLongLen int + hasShort bool + hasValueName bool + terminalColumns int + indent bool +} + +const ( + paddingBeforeOption = 2 + distanceBetweenOptionAndDescription = 2 +) + +func (a *alignmentInfo) descriptionStart() int { + ret := a.maxLongLen + distanceBetweenOptionAndDescription + + if a.hasShort { + ret += 2 + } + + if a.maxLongLen > 0 { + ret += 4 + } + + if a.hasValueName { + ret += 3 + } + + return ret +} + +func (a *alignmentInfo) updateLen(name string, indent bool) { + l := utf8.RuneCountInString(name) + + if indent { + l = l + 4 + } + + if l > a.maxLongLen { + a.maxLongLen = l + } +} + +func (p *Parser) getAlignmentInfo() alignmentInfo { + ret := alignmentInfo{ + maxLongLen: 0, + hasShort: false, + hasValueName: false, + terminalColumns: getTerminalColumns(), + } + + if ret.terminalColumns <= 0 { + ret.terminalColumns = 80 + } + + var prevcmd *Command + + p.eachActiveGroup(func(c *Command, grp *Group) { + if c != prevcmd { + for _, arg := range c.args { + ret.updateLen(arg.Name, c != p.Command) + } + } + + for _, info := range grp.options { + if !info.canCli() { + continue + } + + if info.ShortName != 0 { + ret.hasShort = true + } + + if len(info.ValueName) > 0 { + ret.hasValueName = true + } + + ret.updateLen(info.LongNameWithNamespace()+info.ValueName, c != p.Command) + } + }) + + return ret +} + +func (p *Parser) writeHelpOption(writer *bufio.Writer, option *Option, info alignmentInfo) { + line := &bytes.Buffer{} + + prefix := paddingBeforeOption + + if info.indent { + prefix += 4 + } + + line.WriteString(strings.Repeat(" ", prefix)) + + if option.ShortName != 0 { + line.WriteRune(defaultShortOptDelimiter) + line.WriteRune(option.ShortName) + } else if info.hasShort { + line.WriteString(" ") + } + + descstart := info.descriptionStart() + paddingBeforeOption + + if len(option.LongName) > 0 { + if option.ShortName != 0 { + line.WriteString(", ") + } else if info.hasShort { + line.WriteString(" ") + } + + line.WriteString(defaultLongOptDelimiter) + line.WriteString(option.LongNameWithNamespace()) + } + + if option.canArgument() { + line.WriteRune(defaultNameArgDelimiter) + + if len(option.ValueName) > 0 { + line.WriteString(option.ValueName) + } + } + + written := line.Len() + line.WriteTo(writer) + + if option.Description != "" { + dw := descstart - written + writer.WriteString(strings.Repeat(" ", dw)) + + def := "" + defs := option.Default + + if len(option.DefaultMask) != 0 { + if option.DefaultMask != "-" { + def = option.DefaultMask + } + } else if len(defs) == 0 && option.canArgument() { + var showdef bool + + switch option.field.Type.Kind() { + case reflect.Func, reflect.Ptr: + showdef = !option.value.IsNil() + case reflect.Slice, reflect.String, reflect.Array: + showdef = option.value.Len() > 0 + case reflect.Map: + showdef = !option.value.IsNil() && option.value.Len() > 0 + default: + zeroval := reflect.Zero(option.field.Type) + showdef = !reflect.DeepEqual(zeroval.Interface(), option.value.Interface()) + } + + if showdef { + def, _ = convertToString(option.value, option.tag) + } + } else if len(defs) != 0 { + l := len(defs) - 1 + + for i := 0; i < l; i++ { + def += quoteIfNeeded(defs[i]) + ", " + } + + def += quoteIfNeeded(defs[l]) + } + + var envDef string + if option.EnvDefaultKey != "" { + var envPrintable string + if runtime.GOOS == "windows" { + envPrintable = "%" + option.EnvDefaultKey + "%" + } else { + envPrintable = "$" + option.EnvDefaultKey + } + envDef = fmt.Sprintf(" [%s]", envPrintable) + } + + var desc string + + if def != "" { + desc = fmt.Sprintf("%s (%v)%s", option.Description, def, envDef) + } else { + desc = option.Description + envDef + } + + writer.WriteString(wrapText(desc, + info.terminalColumns-descstart, + strings.Repeat(" ", descstart))) + } + + writer.WriteString("\n") +} + +func maxCommandLength(s []*Command) int { + if len(s) == 0 { + return 0 + } + + ret := len(s[0].Name) + + for _, v := range s[1:] { + l := len(v.Name) + + if l > ret { + ret = l + } + } + + return ret +} + +// WriteHelp writes a help message containing all the possible options and +// their descriptions to the provided writer. Note that the HelpFlag parser +// option provides a convenient way to add a -h/--help option group to the +// command line parser which will automatically show the help messages using +// this method. +func (p *Parser) WriteHelp(writer io.Writer) { + if writer == nil { + return + } + + wr := bufio.NewWriter(writer) + aligninfo := p.getAlignmentInfo() + + cmd := p.Command + + for cmd.Active != nil { + cmd = cmd.Active + } + + if p.Name != "" { + wr.WriteString("Usage:\n") + wr.WriteString(" ") + + allcmd := p.Command + + for allcmd != nil { + var usage string + + if allcmd == p.Command { + if len(p.Usage) != 0 { + usage = p.Usage + } else if p.Options&HelpFlag != 0 { + usage = "[OPTIONS]" + } + } else if us, ok := allcmd.data.(Usage); ok { + usage = us.Usage() + } else if allcmd.hasCliOptions() { + usage = fmt.Sprintf("[%s-OPTIONS]", allcmd.Name) + } + + if len(usage) != 0 { + fmt.Fprintf(wr, " %s %s", allcmd.Name, usage) + } else { + fmt.Fprintf(wr, " %s", allcmd.Name) + } + + if len(allcmd.args) > 0 { + fmt.Fprintf(wr, " ") + } + + for i, arg := range allcmd.args { + if i != 0 { + fmt.Fprintf(wr, " ") + } + + name := arg.Name + + if arg.isRemaining() { + name = name + "..." + } + + if !allcmd.ArgsRequired { + fmt.Fprintf(wr, "[%s]", name) + } else { + fmt.Fprintf(wr, "%s", name) + } + } + + if allcmd.Active == nil && len(allcmd.commands) > 0 { + var co, cc string + + if allcmd.SubcommandsOptional { + co, cc = "[", "]" + } else { + co, cc = "<", ">" + } + + if len(allcmd.commands) > 3 { + fmt.Fprintf(wr, " %scommand%s", co, cc) + } else { + subcommands := allcmd.sortedCommands() + names := make([]string, len(subcommands)) + + for i, subc := range subcommands { + names[i] = subc.Name + } + + fmt.Fprintf(wr, " %s%s%s", co, strings.Join(names, " | "), cc) + } + } + + allcmd = allcmd.Active + } + + fmt.Fprintln(wr) + + if len(cmd.LongDescription) != 0 { + fmt.Fprintln(wr) + + t := wrapText(cmd.LongDescription, + aligninfo.terminalColumns, + "") + + fmt.Fprintln(wr, t) + } + } + + c := p.Command + + for c != nil { + printcmd := c != p.Command + + c.eachGroup(func(grp *Group) { + first := true + + // Skip built-in help group for all commands except the top-level + // parser + if grp.isBuiltinHelp && c != p.Command { + return + } + + for _, info := range grp.options { + if !info.canCli() { + continue + } + + if printcmd { + fmt.Fprintf(wr, "\n[%s command options]\n", c.Name) + aligninfo.indent = true + printcmd = false + } + + if first && cmd.Group != grp { + fmt.Fprintln(wr) + + if aligninfo.indent { + wr.WriteString(" ") + } + + fmt.Fprintf(wr, "%s:\n", grp.ShortDescription) + first = false + } + + p.writeHelpOption(wr, info, aligninfo) + } + }) + + if len(c.args) > 0 { + if c == p.Command { + fmt.Fprintf(wr, "\nArguments:\n") + } else { + fmt.Fprintf(wr, "\n[%s command arguments]\n", c.Name) + } + + maxlen := aligninfo.descriptionStart() + + for _, arg := range c.args { + prefix := strings.Repeat(" ", paddingBeforeOption) + fmt.Fprintf(wr, "%s%s", prefix, arg.Name) + + if len(arg.Description) > 0 { + align := strings.Repeat(" ", maxlen-len(arg.Name)-1) + fmt.Fprintf(wr, ":%s%s", align, arg.Description) + } + + fmt.Fprintln(wr) + } + } + + c = c.Active + } + + scommands := cmd.sortedCommands() + + if len(scommands) > 0 { + maxnamelen := maxCommandLength(scommands) + + fmt.Fprintln(wr) + fmt.Fprintln(wr, "Available commands:") + + for _, c := range scommands { + fmt.Fprintf(wr, " %s", c.Name) + + if len(c.ShortDescription) > 0 { + pad := strings.Repeat(" ", maxnamelen-len(c.Name)) + fmt.Fprintf(wr, "%s %s", pad, c.ShortDescription) + + if len(c.Aliases) > 0 { + fmt.Fprintf(wr, " (aliases: %s)", strings.Join(c.Aliases, ", ")) + } + + } + + fmt.Fprintln(wr) + } + } + + wr.Flush() +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/help_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/help_test.go new file mode 100644 index 000000000..c80288301 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/help_test.go @@ -0,0 +1,292 @@ +package flags + +import ( + "bytes" + "fmt" + "os" + "runtime" + "testing" + "time" +) + +type helpOptions struct { + Verbose []bool `short:"v" long:"verbose" description:"Show verbose debug information" ini-name:"verbose"` + Call func(string) `short:"c" description:"Call phone number" ini-name:"call"` + PtrSlice []*string `long:"ptrslice" description:"A slice of pointers to string"` + EmptyDescription bool `long:"empty-description"` + + Default string `long:"default" default:"Some\nvalue" description:"Test default value"` + DefaultArray []string `long:"default-array" default:"Some value" default:"Other\tvalue" description:"Test default array value"` + DefaultMap map[string]string `long:"default-map" default:"some:value" default:"another:value" description:"Testdefault map value"` + EnvDefault1 string `long:"env-default1" default:"Some value" env:"ENV_DEFAULT" description:"Test env-default1 value"` + EnvDefault2 string `long:"env-default2" env:"ENV_DEFAULT" description:"Test env-default2 value"` + OptionWithArgName string `long:"opt-with-arg-name" value-name:"something" description:"Option with named argument"` + + OnlyIni string `ini-name:"only-ini" description:"Option only available in ini"` + + Other struct { + StringSlice []string `short:"s" default:"some" default:"value" description:"A slice of strings"` + IntMap map[string]int `long:"intmap" default:"a:1" description:"A map from string to int" ini-name:"int-map"` + } `group:"Other Options"` + + Group struct { + Opt string `long:"opt" description:"This is a subgroup option"` + + Group struct { + Opt string `long:"opt" description:"This is a subsubgroup option"` + } `group:"Subsubgroup" namespace:"sap"` + } `group:"Subgroup" namespace:"sip"` + + Command struct { + ExtraVerbose []bool `long:"extra-verbose" description:"Use for extra verbosity"` + } `command:"command" alias:"cm" alias:"cmd" description:"A command"` + + Args struct { + Filename string `positional-arg-name:"filename" description:"A filename"` + Number int `positional-arg-name:"num" description:"A number"` + } `positional-args:"yes"` +} + +func TestHelp(t *testing.T) { + oldEnv := EnvSnapshot() + defer oldEnv.Restore() + os.Setenv("ENV_DEFAULT", "env-def") + + var opts helpOptions + p := NewNamedParser("TestHelp", HelpFlag) + p.AddGroup("Application Options", "The application options", &opts) + + _, err := p.ParseArgs([]string{"--help"}) + + if err == nil { + t.Fatalf("Expected help error") + } + + if e, ok := err.(*Error); !ok { + t.Fatalf("Expected flags.Error, but got %T", err) + } else { + if e.Type != ErrHelp { + t.Errorf("Expected flags.ErrHelp type, but got %s", e.Type) + } + + var expected string + + if runtime.GOOS == "windows" { + expected = `Usage: + TestHelp [OPTIONS] [filename] [num] + +Application Options: + /v, /verbose Show verbose debug information + /c: Call phone number + /ptrslice: A slice of pointers to string + /empty-description + /default: Test default value ("Some\nvalue") + /default-array: Test default array value (Some value, "Other\tvalue") + /default-map: Testdefault map value (some:value, another:value) + /env-default1: Test env-default1 value (Some value) [%ENV_DEFAULT%] + /env-default2: Test env-default2 value [%ENV_DEFAULT%] + /opt-with-arg-name:something Option with named argument + +Other Options: + /s: A slice of strings (some, value) + /intmap: A map from string to int (a:1) + +Subgroup: + /sip.opt: This is a subgroup option + +Subsubgroup: + /sip.sap.opt: This is a subsubgroup option + +Help Options: + /? Show this help message + /h, /help Show this help message + +Arguments: + filename: A filename + num: A number + +Available commands: + command A command (aliases: cm, cmd) +` + } else { + expected = `Usage: + TestHelp [OPTIONS] [filename] [num] + +Application Options: + -v, --verbose Show verbose debug information + -c= Call phone number + --ptrslice= A slice of pointers to string + --empty-description + --default= Test default value ("Some\nvalue") + --default-array= Test default array value (Some value, + "Other\tvalue") + --default-map= Testdefault map value (some:value, + another:value) + --env-default1= Test env-default1 value (Some value) + [$ENV_DEFAULT] + --env-default2= Test env-default2 value [$ENV_DEFAULT] + --opt-with-arg-name=something Option with named argument + +Other Options: + -s= A slice of strings (some, value) + --intmap= A map from string to int (a:1) + +Subgroup: + --sip.opt= This is a subgroup option + +Subsubgroup: + --sip.sap.opt= This is a subsubgroup option + +Help Options: + -h, --help Show this help message + +Arguments: + filename: A filename + num: A number + +Available commands: + command A command (aliases: cm, cmd) +` + } + + assertDiff(t, e.Message, expected, "help message") + } +} + +func TestMan(t *testing.T) { + oldEnv := EnvSnapshot() + defer oldEnv.Restore() + os.Setenv("ENV_DEFAULT", "env-def") + + var opts helpOptions + p := NewNamedParser("TestMan", HelpFlag) + p.ShortDescription = "Test manpage generation" + p.LongDescription = "This is a somewhat `longer' description of what this does" + p.AddGroup("Application Options", "The application options", &opts) + + p.Commands()[0].LongDescription = "Longer `command' description" + + var buf bytes.Buffer + p.WriteManPage(&buf) + + got := buf.String() + + tt := time.Now() + + expected := fmt.Sprintf(`.TH TestMan 1 "%s" +.SH NAME +TestMan \- Test manpage generation +.SH SYNOPSIS +\fBTestMan\fP [OPTIONS] +.SH DESCRIPTION +This is a somewhat \fBlonger\fP description of what this does +.SH OPTIONS +.TP +\fB-v, --verbose\fP +Show verbose debug information +.TP +\fB-c\fP +Call phone number +.TP +\fB--ptrslice\fP +A slice of pointers to string +.TP +\fB--empty-description\fP +.TP +\fB--default\fP +Test default value +.TP +\fB--default-array\fP +Test default array value +.TP +\fB--default-map\fP +Testdefault map value +.TP +\fB--env-default1\fP +Test env-default1 value +.TP +\fB--env-default2\fP +Test env-default2 value +.TP +\fB--opt-with-arg-name\fP +Option with named argument +.TP +\fB-s\fP +A slice of strings +.TP +\fB--intmap\fP +A map from string to int +.TP +\fB--sip.opt\fP +This is a subgroup option +.TP +\fB--sip.sap.opt\fP +This is a subsubgroup option +.SH COMMANDS +.SS command +A command + +Longer \fBcommand\fP description + +\fBUsage\fP: TestMan [OPTIONS] command [command-OPTIONS] + + +\fBAliases\fP: cm, cmd + +.TP +\fB--extra-verbose\fP +Use for extra verbosity +`, tt.Format("2 January 2006")) + + assertDiff(t, got, expected, "man page") +} + +type helpCommandNoOptions struct { + Command struct { + } `command:"command" description:"A command"` +} + +func TestHelpCommand(t *testing.T) { + oldEnv := EnvSnapshot() + defer oldEnv.Restore() + os.Setenv("ENV_DEFAULT", "env-def") + + var opts helpCommandNoOptions + p := NewNamedParser("TestHelpCommand", HelpFlag) + p.AddGroup("Application Options", "The application options", &opts) + + _, err := p.ParseArgs([]string{"command", "--help"}) + + if err == nil { + t.Fatalf("Expected help error") + } + + if e, ok := err.(*Error); !ok { + t.Fatalf("Expected flags.Error, but got %T", err) + } else { + if e.Type != ErrHelp { + t.Errorf("Expected flags.ErrHelp type, but got %s", e.Type) + } + + var expected string + + if runtime.GOOS == "windows" { + expected = `Usage: + TestHelpCommand [OPTIONS] command + +Help Options: + /? Show this help message + /h, /help Show this help message +` + } else { + expected = `Usage: + TestHelpCommand [OPTIONS] command + +Help Options: + -h, --help Show this help message +` + } + + assertDiff(t, e.Message, expected, "help message") + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini.go new file mode 100644 index 000000000..722505225 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini.go @@ -0,0 +1,140 @@ +package flags + +import ( + "fmt" + "io" +) + +// IniError contains location information on where an error occured. +type IniError struct { + // The error message. + Message string + + // The filename of the file in which the error occurred. + File string + + // The line number at which the error occurred. + LineNumber uint +} + +// Error provides a "file:line: message" formatted message of the ini error. +func (x *IniError) Error() string { + return fmt.Sprintf( + "%s:%d: %s", + x.File, + x.LineNumber, + x.Message, + ) +} + +// IniOptions for writing +type IniOptions uint + +const ( + // IniNone indicates no options. + IniNone IniOptions = 0 + + // IniIncludeDefaults indicates that default values should be written. + IniIncludeDefaults = 1 << iota + + // IniCommentDefaults indicates that if IniIncludeDefaults is used + // options with default values are written but commented out. + IniCommentDefaults + + // IniIncludeComments indicates that comments containing the description + // of an option should be written. + IniIncludeComments + + // IniDefault provides a default set of options. + IniDefault = IniIncludeComments +) + +// IniParser is a utility to read and write flags options from and to ini +// formatted strings. +type IniParser struct { + parser *Parser +} + +// NewIniParser creates a new ini parser for a given Parser. +func NewIniParser(p *Parser) *IniParser { + return &IniParser{ + parser: p, + } +} + +// IniParse is a convenience function to parse command line options with default +// settings from an ini formatted file. The provided data is a pointer to a struct +// representing the default option group (named "Application Options"). For +// more control, use flags.NewParser. +func IniParse(filename string, data interface{}) error { + p := NewParser(data, Default) + + return NewIniParser(p).ParseFile(filename) +} + +// ParseFile parses flags from an ini formatted file. See Parse for more +// information on the ini file format. The returned errors can be of the type +// flags.Error or flags.IniError. +func (i *IniParser) ParseFile(filename string) error { + i.parser.clearIsSet() + + ini, err := readIniFromFile(filename) + + if err != nil { + return err + } + + return i.parse(ini) +} + +// Parse parses flags from an ini format. You can use ParseFile as a +// convenience function to parse from a filename instead of a general +// io.Reader. +// +// The format of the ini file is as follows: +// +// [Option group name] +// option = value +// +// Each section in the ini file represents an option group or command in the +// flags parser. The default flags parser option group (i.e. when using +// flags.Parse) is named 'Application Options'. The ini option name is matched +// in the following order: +// +// 1. Compared to the ini-name tag on the option struct field (if present) +// 2. Compared to the struct field name +// 3. Compared to the option long name (if present) +// 4. Compared to the option short name (if present) +// +// Sections for nested groups and commands can be addressed using a dot `.' +// namespacing notation (i.e [subcommand.Options]). Group section names are +// matched case insensitive. +// +// The returned errors can be of the type flags.Error or flags.IniError. +func (i *IniParser) Parse(reader io.Reader) error { + i.parser.clearIsSet() + + ini, err := readIni(reader, "") + + if err != nil { + return err + } + + return i.parse(ini) +} + +// WriteFile writes the flags as ini format into a file. See WriteIni +// for more information. The returned error occurs when the specified file +// could not be opened for writing. +func (i *IniParser) WriteFile(filename string, options IniOptions) error { + return writeIniToFile(i, filename, options) +} + +// Write writes the current values of all the flags to an ini format. +// See Parse for more information on the ini file format. You typically +// call this only after settings have been parsed since the default values of each +// option are stored just before parsing the flags (this is only relevant when +// IniIncludeDefaults is _not_ set in options). +func (i *IniParser) Write(writer io.Writer, options IniOptions) { + writeIni(i, writer, options) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_private.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_private.go new file mode 100644 index 000000000..887aa7679 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_private.go @@ -0,0 +1,452 @@ +package flags + +import ( + "bufio" + "fmt" + "io" + "os" + "reflect" + "sort" + "strconv" + "strings" +) + +type iniValue struct { + Name string + Value string + Quoted bool + LineNumber uint +} + +type iniSection []iniValue +type ini struct { + File string + Sections map[string]iniSection +} + +func readFullLine(reader *bufio.Reader) (string, error) { + var line []byte + + for { + l, more, err := reader.ReadLine() + + if err != nil { + return "", err + } + + if line == nil && !more { + return string(l), nil + } + + line = append(line, l...) + + if !more { + break + } + } + + return string(line), nil +} + +func optionIniName(option *Option) string { + name := option.tag.Get("_read-ini-name") + + if len(name) != 0 { + return name + } + + name = option.tag.Get("ini-name") + + if len(name) != 0 { + return name + } + + return option.field.Name +} + +func writeGroupIni(cmd *Command, group *Group, namespace string, writer io.Writer, options IniOptions) { + var sname string + + if len(namespace) != 0 { + sname = namespace + } + + if cmd.Group != group && len(group.ShortDescription) != 0 { + if len(sname) != 0 { + sname += "." + } + + sname += group.ShortDescription + } + + sectionwritten := false + comments := (options & IniIncludeComments) != IniNone + + for _, option := range group.options { + if option.isFunc() { + continue + } + + if len(option.tag.Get("no-ini")) != 0 { + continue + } + + val := option.value + + if (options&IniIncludeDefaults) == IniNone && option.valueIsDefault() { + continue + } + + if !sectionwritten { + fmt.Fprintf(writer, "[%s]\n", sname) + sectionwritten = true + } + + if comments && len(option.Description) != 0 { + fmt.Fprintf(writer, "; %s\n", option.Description) + } + + oname := optionIniName(option) + + commentOption := (options&(IniIncludeDefaults|IniCommentDefaults)) == IniIncludeDefaults|IniCommentDefaults && option.valueIsDefault() + + kind := val.Type().Kind() + switch kind { + case reflect.Slice: + kind = val.Type().Elem().Kind() + + if val.Len() == 0 { + writeOption(writer, oname, kind, "", "", true, option.iniQuote) + } else { + for idx := 0; idx < val.Len(); idx++ { + v, _ := convertToString(val.Index(idx), option.tag) + + writeOption(writer, oname, kind, "", v, commentOption, option.iniQuote) + } + } + case reflect.Map: + kind = val.Type().Elem().Kind() + + if val.Len() == 0 { + writeOption(writer, oname, kind, "", "", true, option.iniQuote) + } else { + mkeys := val.MapKeys() + keys := make([]string, len(val.MapKeys())) + kkmap := make(map[string]reflect.Value) + + for i, k := range mkeys { + keys[i], _ = convertToString(k, option.tag) + kkmap[keys[i]] = k + } + + sort.Strings(keys) + + for _, k := range keys { + v, _ := convertToString(val.MapIndex(kkmap[k]), option.tag) + + writeOption(writer, oname, kind, k, v, commentOption, option.iniQuote) + } + } + default: + v, _ := convertToString(val, option.tag) + + writeOption(writer, oname, kind, "", v, commentOption, option.iniQuote) + } + + if comments { + fmt.Fprintln(writer) + } + } + + if sectionwritten && !comments { + fmt.Fprintln(writer) + } +} + +func writeOption(writer io.Writer, optionName string, optionType reflect.Kind, optionKey string, optionValue string, commentOption bool, forceQuote bool) { + if forceQuote || (optionType == reflect.String && !isPrint(optionValue)) { + optionValue = strconv.Quote(optionValue) + } + + comment := "" + if commentOption { + comment = "; " + } + + fmt.Fprintf(writer, "%s%s =", comment, optionName) + + if optionKey != "" { + fmt.Fprintf(writer, " %s:%s", optionKey, optionValue) + } else if optionValue != "" { + fmt.Fprintf(writer, " %s", optionValue) + } + + fmt.Fprintln(writer) +} + +func writeCommandIni(command *Command, namespace string, writer io.Writer, options IniOptions) { + command.eachGroup(func(group *Group) { + writeGroupIni(command, group, namespace, writer, options) + }) + + for _, c := range command.commands { + var nns string + + if len(namespace) != 0 { + nns = c.Name + "." + nns + } else { + nns = c.Name + } + + writeCommandIni(c, nns, writer, options) + } +} + +func writeIni(parser *IniParser, writer io.Writer, options IniOptions) { + writeCommandIni(parser.parser.Command, "", writer, options) +} + +func writeIniToFile(parser *IniParser, filename string, options IniOptions) error { + file, err := os.Create(filename) + + if err != nil { + return err + } + + defer file.Close() + + writeIni(parser, file, options) + + return nil +} + +func readIniFromFile(filename string) (*ini, error) { + file, err := os.Open(filename) + + if err != nil { + return nil, err + } + + defer file.Close() + + return readIni(file, filename) +} + +func readIni(contents io.Reader, filename string) (*ini, error) { + ret := &ini{ + File: filename, + Sections: make(map[string]iniSection), + } + + reader := bufio.NewReader(contents) + + // Empty global section + section := make(iniSection, 0, 10) + sectionname := "" + + ret.Sections[sectionname] = section + + var lineno uint + + for { + line, err := readFullLine(reader) + + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + lineno++ + line = strings.TrimSpace(line) + + // Skip empty lines and lines starting with ; (comments) + if len(line) == 0 || line[0] == ';' || line[0] == '#' { + continue + } + + if line[0] == '[' { + if line[0] != '[' || line[len(line)-1] != ']' { + return nil, &IniError{ + Message: "malformed section header", + File: filename, + LineNumber: lineno, + } + } + + name := strings.TrimSpace(line[1 : len(line)-1]) + + if len(name) == 0 { + return nil, &IniError{ + Message: "empty section name", + File: filename, + LineNumber: lineno, + } + } + + sectionname = name + section = ret.Sections[name] + + if section == nil { + section = make(iniSection, 0, 10) + ret.Sections[name] = section + } + + continue + } + + // Parse option here + keyval := strings.SplitN(line, "=", 2) + + if len(keyval) != 2 { + return nil, &IniError{ + Message: fmt.Sprintf("malformed key=value (%s)", line), + File: filename, + LineNumber: lineno, + } + } + + name := strings.TrimSpace(keyval[0]) + value := strings.TrimSpace(keyval[1]) + quoted := false + + if len(value) != 0 && value[0] == '"' { + if v, err := strconv.Unquote(value); err == nil { + value = v + + quoted = true + } else { + return nil, &IniError{ + Message: err.Error(), + File: filename, + LineNumber: lineno, + } + } + } + + section = append(section, iniValue{ + Name: name, + Value: value, + Quoted: quoted, + LineNumber: lineno, + }) + + ret.Sections[sectionname] = section + } + + return ret, nil +} + +func (i *IniParser) matchingGroups(name string) []*Group { + if len(name) == 0 { + var ret []*Group + + i.parser.eachGroup(func(g *Group) { + ret = append(ret, g) + }) + + return ret + } + + g := i.parser.groupByName(name) + + if g != nil { + return []*Group{g} + } + + return nil +} + +func (i *IniParser) parse(ini *ini) error { + p := i.parser + + var quotesLookup = make(map[*Option]bool) + + for name, section := range ini.Sections { + groups := i.matchingGroups(name) + + if len(groups) == 0 { + return newErrorf(ErrUnknownGroup, "could not find option group `%s'", name) + } + + for _, inival := range section { + var opt *Option + + for _, group := range groups { + opt = group.optionByName(inival.Name, func(o *Option, n string) bool { + return strings.ToLower(o.tag.Get("ini-name")) == strings.ToLower(n) + }) + + if opt != nil && len(opt.tag.Get("no-ini")) != 0 { + opt = nil + } + + if opt != nil { + break + } + } + + if opt == nil { + if (p.Options & IgnoreUnknown) == None { + return &IniError{ + Message: fmt.Sprintf("unknown option: %s", inival.Name), + File: ini.File, + LineNumber: inival.LineNumber, + } + } + + continue + } + + pval := &inival.Value + + if !opt.canArgument() && len(inival.Value) == 0 { + pval = nil + } else { + if opt.value.Type().Kind() == reflect.Map { + parts := strings.SplitN(inival.Value, ":", 2) + + // only handle unquoting + if len(parts) == 2 && parts[1][0] == '"' { + if v, err := strconv.Unquote(parts[1]); err == nil { + parts[1] = v + + inival.Quoted = true + } else { + return &IniError{ + Message: err.Error(), + File: ini.File, + LineNumber: inival.LineNumber, + } + } + + s := parts[0] + ":" + parts[1] + + pval = &s + } + } + } + + if err := opt.set(pval); err != nil { + return &IniError{ + Message: err.Error(), + File: ini.File, + LineNumber: inival.LineNumber, + } + } + + // either all INI values are quoted or only values who need quoting + if _, ok := quotesLookup[opt]; !inival.Quoted || !ok { + quotesLookup[opt] = inival.Quoted + } + + opt.tag.Set("_read-ini-name", inival.Name) + } + } + + for opt, quoted := range quotesLookup { + opt.iniQuote = quoted + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_test.go new file mode 100644 index 000000000..215b75795 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/ini_test.go @@ -0,0 +1,767 @@ +package flags + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "reflect" + "strings" + "testing" +) + +func TestWriteIni(t *testing.T) { + oldEnv := EnvSnapshot() + defer oldEnv.Restore() + os.Setenv("ENV_DEFAULT", "env-def") + + var tests = []struct { + args []string + options IniOptions + expected string + }{ + { + []string{"-vv", "--intmap=a:2", "--intmap", "b:3", "filename", "0", "command"}, + IniDefault, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +[Other Options] +; A map from string to int +int-map = a:2 +int-map = b:3 + +`, + }, + { + []string{"-vv", "--intmap=a:2", "--intmap", "b:3", "filename", "0", "command"}, + IniDefault | IniIncludeDefaults, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; A slice of pointers to string +; PtrSlice = + +EmptyDescription = false + +; Test default value +Default = "Some\nvalue" + +; Test default array value +DefaultArray = Some value +DefaultArray = "Other\tvalue" + +; Testdefault map value +DefaultMap = another:value +DefaultMap = some:value + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +; Option with named argument +OptionWithArgName = + +; Option only available in ini +only-ini = + +[Other Options] +; A slice of strings +StringSlice = some +StringSlice = value + +; A map from string to int +int-map = a:2 +int-map = b:3 + +[Subgroup] +; This is a subgroup option +Opt = + +[Subsubgroup] +; This is a subsubgroup option +Opt = + +[command] +; Use for extra verbosity +; ExtraVerbose = + +`, + }, + { + []string{"filename", "0", "command"}, + IniDefault | IniIncludeDefaults | IniCommentDefaults, + `[Application Options] +; Show verbose debug information +; verbose = + +; A slice of pointers to string +; PtrSlice = + +; EmptyDescription = false + +; Test default value +; Default = "Some\nvalue" + +; Test default array value +; DefaultArray = Some value +; DefaultArray = "Other\tvalue" + +; Testdefault map value +; DefaultMap = another:value +; DefaultMap = some:value + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +; Option with named argument +; OptionWithArgName = + +; Option only available in ini +; only-ini = + +[Other Options] +; A slice of strings +; StringSlice = some +; StringSlice = value + +; A map from string to int +; int-map = a:1 + +[Subgroup] +; This is a subgroup option +; Opt = + +[Subsubgroup] +; This is a subsubgroup option +; Opt = + +[command] +; Use for extra verbosity +; ExtraVerbose = + +`, + }, + { + []string{"--default=New value", "--default-array=New value", "--default-map=new:value", "filename", "0", "command"}, + IniDefault | IniIncludeDefaults | IniCommentDefaults, + `[Application Options] +; Show verbose debug information +; verbose = + +; A slice of pointers to string +; PtrSlice = + +; EmptyDescription = false + +; Test default value +Default = New value + +; Test default array value +DefaultArray = New value + +; Testdefault map value +DefaultMap = new:value + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +; Option with named argument +; OptionWithArgName = + +; Option only available in ini +; only-ini = + +[Other Options] +; A slice of strings +; StringSlice = some +; StringSlice = value + +; A map from string to int +; int-map = a:1 + +[Subgroup] +; This is a subgroup option +; Opt = + +[Subsubgroup] +; This is a subsubgroup option +; Opt = + +[command] +; Use for extra verbosity +; ExtraVerbose = + +`, + }, + } + + for _, test := range tests { + var opts helpOptions + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + _, err := p.ParseArgs(test.args) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + inip := NewIniParser(p) + + var b bytes.Buffer + inip.Write(&b, test.options) + + got := b.String() + expected := test.expected + + msg := fmt.Sprintf("with arguments %+v and ini options %b", test.args, test.options) + assertDiff(t, got, expected, msg) + } +} + +func TestReadIni(t *testing.T) { + var opts helpOptions + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + inip := NewIniParser(p) + + inic := ` +; Show verbose debug information +verbose = true +verbose = true + +DefaultMap = another:"value\n1" +DefaultMap = some:value 2 + +[Application Options] +; A slice of pointers to string +; PtrSlice = + +; Test default value +Default = "New\nvalue" + +; Test env-default1 value +EnvDefault1 = New value + +[Other Options] +# A slice of strings +StringSlice = "some\nvalue" +StringSlice = another value + +; A map from string to int +int-map = a:2 +int-map = b:3 + +` + + b := strings.NewReader(inic) + err := inip.Parse(b) + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + assertBoolArray(t, opts.Verbose, []bool{true, true}) + + if v := map[string]string{"another": "value\n1", "some": "value 2"}; !reflect.DeepEqual(opts.DefaultMap, v) { + t.Fatalf("Expected %#v for DefaultMap but got %#v", v, opts.DefaultMap) + } + + assertString(t, opts.Default, "New\nvalue") + + assertString(t, opts.EnvDefault1, "New value") + + assertStringArray(t, opts.Other.StringSlice, []string{"some\nvalue", "another value"}) + + if v, ok := opts.Other.IntMap["a"]; !ok { + t.Errorf("Expected \"a\" in Other.IntMap") + } else if v != 2 { + t.Errorf("Expected Other.IntMap[\"a\"] = 2, but got %v", v) + } + + if v, ok := opts.Other.IntMap["b"]; !ok { + t.Errorf("Expected \"b\" in Other.IntMap") + } else if v != 3 { + t.Errorf("Expected Other.IntMap[\"b\"] = 3, but got %v", v) + } +} + +func TestReadAndWriteIni(t *testing.T) { + var tests = []struct { + options IniOptions + read string + write string + }{ + { + IniIncludeComments, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; Test default value +Default = "quote me" + +; Test default array value +DefaultArray = 1 +DefaultArray = "2" +DefaultArray = 3 + +; Testdefault map value +; DefaultMap = + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +[Other Options] +; A slice of strings +; StringSlice = + +; A map from string to int +int-map = a:2 +int-map = b:"3" + +`, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; Test default value +Default = "quote me" + +; Test default array value +DefaultArray = 1 +DefaultArray = 2 +DefaultArray = 3 + +; Testdefault map value +; DefaultMap = + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +[Other Options] +; A slice of strings +; StringSlice = + +; A map from string to int +int-map = a:2 +int-map = b:3 + +`, + }, + { + IniIncludeComments, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; Test default value +Default = "quote me" + +; Test default array value +DefaultArray = "1" +DefaultArray = "2" +DefaultArray = "3" + +; Testdefault map value +; DefaultMap = + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +[Other Options] +; A slice of strings +; StringSlice = + +; A map from string to int +int-map = a:"2" +int-map = b:"3" + +`, + `[Application Options] +; Show verbose debug information +verbose = true +verbose = true + +; Test default value +Default = "quote me" + +; Test default array value +DefaultArray = "1" +DefaultArray = "2" +DefaultArray = "3" + +; Testdefault map value +; DefaultMap = + +; Test env-default1 value +EnvDefault1 = env-def + +; Test env-default2 value +EnvDefault2 = env-def + +[Other Options] +; A slice of strings +; StringSlice = + +; A map from string to int +int-map = a:"2" +int-map = b:"3" + +`, + }, + } + + for _, test := range tests { + var opts helpOptions + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + inip := NewIniParser(p) + + read := strings.NewReader(test.read) + err := inip.Parse(read) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + var write bytes.Buffer + inip.Write(&write, test.options) + + got := write.String() + + msg := fmt.Sprintf("with ini options %b", test.options) + assertDiff(t, got, test.write, msg) + } +} + +func TestReadIniWrongQuoting(t *testing.T) { + var tests = []struct { + iniFile string + lineNumber uint + }{ + { + iniFile: `Default = "New\nvalue`, + lineNumber: 1, + }, + { + iniFile: `StringSlice = "New\nvalue`, + lineNumber: 1, + }, + { + iniFile: `StringSlice = "New\nvalue" + StringSlice = "Second\nvalue`, + lineNumber: 2, + }, + { + iniFile: `DefaultMap = some:"value`, + lineNumber: 1, + }, + { + iniFile: `DefaultMap = some:value + DefaultMap = another:"value`, + lineNumber: 2, + }, + } + + for _, test := range tests { + var opts helpOptions + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + inip := NewIniParser(p) + + inic := test.iniFile + + b := strings.NewReader(inic) + err := inip.Parse(b) + + if err == nil { + t.Fatalf("Expect error") + } + + iniError := err.(*IniError) + + if iniError.LineNumber != test.lineNumber { + t.Fatalf("Expect error on line %d", test.lineNumber) + } + } +} + +func TestIniCommands(t *testing.T) { + var opts struct { + Value string `short:"v" long:"value"` + + Add struct { + Name int `short:"n" long:"name" ini-name:"AliasName"` + + Other struct { + O string `short:"o" long:"other"` + } `group:"Other Options"` + } `command:"add"` + } + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + inip := NewIniParser(p) + + inic := `[Application Options] +value = some value + +[add] +AliasName = 5 + +[add.Other Options] +other = subgroup + +` + + b := strings.NewReader(inic) + err := inip.Parse(b) + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + assertString(t, opts.Value, "some value") + + if opts.Add.Name != 5 { + t.Errorf("Expected opts.Add.Name to be 5, but got %v", opts.Add.Name) + } + + assertString(t, opts.Add.Other.O, "subgroup") + + // Test writing it back + buf := &bytes.Buffer{} + + inip.Write(buf, IniDefault) + + assertDiff(t, buf.String(), inic, "ini contents") +} + +func TestIniNoIni(t *testing.T) { + var opts struct { + NoValue string `short:"n" long:"novalue" no-ini:"yes"` + Value string `short:"v" long:"value"` + } + + p := NewNamedParser("TestIni", Default) + p.AddGroup("Application Options", "The application options", &opts) + + inip := NewIniParser(p) + + // read INI + inic := `[Application Options] +novalue = some value +value = some other value +` + + b := strings.NewReader(inic) + err := inip.Parse(b) + + if err == nil { + t.Fatalf("Expected error") + } + + iniError := err.(*IniError) + + if v := uint(2); iniError.LineNumber != v { + t.Errorf("Expected opts.Add.Name to be %d, but got %d", v, iniError.LineNumber) + } + + if v := "unknown option: novalue"; iniError.Message != v { + t.Errorf("Expected opts.Add.Name to be %s, but got %s", v, iniError.Message) + } + + // write INI + opts.NoValue = "some value" + opts.Value = "some other value" + + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("Cannot create temporary file: %s", err) + } + defer os.Remove(file.Name()) + + err = inip.WriteFile(file.Name(), IniIncludeDefaults) + if err != nil { + t.Fatalf("Could not write ini file: %s", err) + } + + found, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatalf("Could not read written ini file: %s", err) + } + + expected := "[Application Options]\nValue = some other value\n\n" + + assertDiff(t, string(found), expected, "ini content") +} + +func TestIniParse(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("Cannot create temporary file: %s", err) + } + defer os.Remove(file.Name()) + + _, err = file.WriteString("value = 123") + if err != nil { + t.Fatalf("Cannot write to temporary file: %s", err) + } + + file.Close() + + var opts struct { + Value int `long:"value"` + } + + err = IniParse(file.Name(), &opts) + if err != nil { + t.Fatalf("Could not parse ini: %s", err) + } + + if opts.Value != 123 { + t.Fatalf("Expected Value to be \"123\" but was \"%d\"", opts.Value) + } +} + +func TestWriteFile(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("Cannot create temporary file: %s", err) + } + defer os.Remove(file.Name()) + + var opts struct { + Value int `long:"value"` + } + + opts.Value = 123 + + p := NewParser(&opts, Default) + ini := NewIniParser(p) + + err = ini.WriteFile(file.Name(), IniIncludeDefaults) + if err != nil { + t.Fatalf("Could not write ini file: %s", err) + } + + found, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatalf("Could not read written ini file: %s", err) + } + + expected := "[Application Options]\nValue = 123\n\n" + + assertDiff(t, string(found), expected, "ini content") +} + +func TestOverwriteRequiredOptions(t *testing.T) { + var tests = []struct { + args []string + expected []string + }{ + { + args: []string{"--value", "from CLI"}, + expected: []string{ + "from CLI", + "from default", + }, + }, + { + args: []string{"--value", "from CLI", "--default", "from CLI"}, + expected: []string{ + "from CLI", + "from CLI", + }, + }, + { + args: []string{"--config", "no file name"}, + expected: []string{ + "from INI", + "from INI", + }, + }, + { + args: []string{"--value", "from CLI before", "--default", "from CLI before", "--config", "no file name"}, + expected: []string{ + "from INI", + "from INI", + }, + }, + { + args: []string{"--value", "from CLI before", "--default", "from CLI before", "--config", "no file name", "--value", "from CLI after", "--default", "from CLI after"}, + expected: []string{ + "from CLI after", + "from CLI after", + }, + }, + } + + for _, test := range tests { + var opts struct { + Config func(s string) error `long:"config" no-ini:"true"` + Value string `long:"value" required:"true"` + Default string `long:"default" required:"true" default:"from default"` + } + + p := NewParser(&opts, Default) + + opts.Config = func(s string) error { + ini := NewIniParser(p) + + return ini.Parse(bytes.NewBufferString("value = from INI\ndefault = from INI")) + } + + _, err := p.ParseArgs(test.args) + if err != nil { + t.Fatalf("Unexpected error %s with args %+v", err, test.args) + } + + if opts.Value != test.expected[0] { + t.Fatalf("Expected Value to be \"%s\" but was \"%s\" with args %+v", test.expected[0], opts.Value, test.args) + } + + if opts.Default != test.expected[1] { + t.Fatalf("Expected Default to be \"%s\" but was \"%s\" with args %+v", test.expected[1], opts.Default, test.args) + } + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/long_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/long_test.go new file mode 100644 index 000000000..02fc8c701 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/long_test.go @@ -0,0 +1,85 @@ +package flags + +import ( + "testing" +) + +func TestLong(t *testing.T) { + var opts = struct { + Value bool `long:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value") + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestLongArg(t *testing.T) { + var opts = struct { + Value string `long:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value", "value") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestLongArgEqual(t *testing.T) { + var opts = struct { + Value string `long:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value=value") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestLongDefault(t *testing.T) { + var opts = struct { + Value string `long:"value" default:"value"` + }{} + + ret := assertParseSuccess(t, &opts) + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestLongOptional(t *testing.T) { + var opts = struct { + Value string `long:"value" optional:"yes" optional-value:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestLongOptionalArg(t *testing.T) { + var opts = struct { + Value string `long:"value" optional:"yes" optional-value:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value", "no") + + assertStringArray(t, ret, []string{"no"}) + assertString(t, opts.Value, "value") +} + +func TestLongOptionalArgEqual(t *testing.T) { + var opts = struct { + Value string `long:"value" optional:"yes" optional-value:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "--value=value", "no") + + assertStringArray(t, ret, []string{"no"}) + assertString(t, opts.Value, "value") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/man.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/man.go new file mode 100644 index 000000000..e8e5916c0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/man.go @@ -0,0 +1,158 @@ +package flags + +import ( + "fmt" + "io" + "strings" + "time" +) + +func formatForMan(wr io.Writer, s string) { + for { + idx := strings.IndexRune(s, '`') + + if idx < 0 { + fmt.Fprintf(wr, "%s", s) + break + } + + fmt.Fprintf(wr, "%s", s[:idx]) + + s = s[idx+1:] + idx = strings.IndexRune(s, '\'') + + if idx < 0 { + fmt.Fprintf(wr, "%s", s) + break + } + + fmt.Fprintf(wr, "\\fB%s\\fP", s[:idx]) + s = s[idx+1:] + } +} + +func writeManPageOptions(wr io.Writer, grp *Group) { + grp.eachGroup(func(group *Group) { + for _, opt := range group.options { + if !opt.canCli() { + continue + } + + fmt.Fprintln(wr, ".TP") + fmt.Fprintf(wr, "\\fB") + + if opt.ShortName != 0 { + fmt.Fprintf(wr, "-%c", opt.ShortName) + } + + if len(opt.LongName) != 0 { + if opt.ShortName != 0 { + fmt.Fprintf(wr, ", ") + } + + fmt.Fprintf(wr, "--%s", opt.LongNameWithNamespace()) + } + + fmt.Fprintln(wr, "\\fP") + if len(opt.Description) != 0 { + formatForMan(wr, opt.Description) + fmt.Fprintln(wr, "") + } + } + }) +} + +func writeManPageSubcommands(wr io.Writer, name string, root *Command) { + commands := root.sortedCommands() + + for _, c := range commands { + var nn string + + if len(name) != 0 { + nn = name + " " + c.Name + } else { + nn = c.Name + } + + writeManPageCommand(wr, nn, root, c) + } +} + +func writeManPageCommand(wr io.Writer, name string, root *Command, command *Command) { + fmt.Fprintf(wr, ".SS %s\n", name) + fmt.Fprintln(wr, command.ShortDescription) + + if len(command.LongDescription) > 0 { + fmt.Fprintln(wr, "") + + cmdstart := fmt.Sprintf("The %s command", command.Name) + + if strings.HasPrefix(command.LongDescription, cmdstart) { + fmt.Fprintf(wr, "The \\fI%s\\fP command", command.Name) + + formatForMan(wr, command.LongDescription[len(cmdstart):]) + fmt.Fprintln(wr, "") + } else { + formatForMan(wr, command.LongDescription) + fmt.Fprintln(wr, "") + } + } + + var usage string + if us, ok := command.data.(Usage); ok { + usage = us.Usage() + } else if command.hasCliOptions() { + usage = fmt.Sprintf("[%s-OPTIONS]", command.Name) + } + + var pre string + if root.hasCliOptions() { + pre = fmt.Sprintf("%s [OPTIONS] %s", root.Name, command.Name) + } else { + pre = fmt.Sprintf("%s %s", root.Name, command.Name) + } + + if len(usage) > 0 { + fmt.Fprintf(wr, "\n\\fBUsage\\fP: %s %s\n\n", pre, usage) + } + + if len(command.Aliases) > 0 { + fmt.Fprintf(wr, "\n\\fBAliases\\fP: %s\n\n", strings.Join(command.Aliases, ", ")) + } + + writeManPageOptions(wr, command.Group) + writeManPageSubcommands(wr, name, command) +} + +// WriteManPage writes a basic man page in groff format to the specified +// writer. +func (p *Parser) WriteManPage(wr io.Writer) { + t := time.Now() + + fmt.Fprintf(wr, ".TH %s 1 \"%s\"\n", p.Name, t.Format("2 January 2006")) + fmt.Fprintln(wr, ".SH NAME") + fmt.Fprintf(wr, "%s \\- %s\n", p.Name, p.ShortDescription) + fmt.Fprintln(wr, ".SH SYNOPSIS") + + usage := p.Usage + + if len(usage) == 0 { + usage = "[OPTIONS]" + } + + fmt.Fprintf(wr, "\\fB%s\\fP %s\n", p.Name, usage) + fmt.Fprintln(wr, ".SH DESCRIPTION") + + formatForMan(wr, p.LongDescription) + fmt.Fprintln(wr, "") + + fmt.Fprintln(wr, ".SH OPTIONS") + + writeManPageOptions(wr, p.Command.Group) + + if len(p.commands) > 0 { + fmt.Fprintln(wr, ".SH COMMANDS") + + writeManPageSubcommands(wr, "", p.Command) + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/marshal_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/marshal_test.go new file mode 100644 index 000000000..59c9ccefb --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/marshal_test.go @@ -0,0 +1,97 @@ +package flags + +import ( + "fmt" + "testing" +) + +type marshalled bool + +func (m *marshalled) UnmarshalFlag(value string) error { + if value == "yes" { + *m = true + } else if value == "no" { + *m = false + } else { + return fmt.Errorf("`%s' is not a valid value, please specify `yes' or `no'", value) + } + + return nil +} + +func (m marshalled) MarshalFlag() (string, error) { + if m { + return "yes", nil + } + + return "no", nil +} + +type marshalledError bool + +func (m marshalledError) MarshalFlag() (string, error) { + return "", newErrorf(ErrMarshal, "Failed to marshal") +} + +func TestUnmarshal(t *testing.T) { + var opts = struct { + Value marshalled `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v=yes") + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestUnmarshalDefault(t *testing.T) { + var opts = struct { + Value marshalled `short:"v" default:"yes"` + }{} + + ret := assertParseSuccess(t, &opts) + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestUnmarshalOptional(t *testing.T) { + var opts = struct { + Value marshalled `short:"v" optional:"yes" optional-value:"yes"` + }{} + + ret := assertParseSuccess(t, &opts, "-v") + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestUnmarshalError(t *testing.T) { + var opts = struct { + Value marshalled `short:"v"` + }{} + + assertParseFail(t, ErrMarshal, fmt.Sprintf("invalid argument for flag `%cv' (expected flags.marshalled): `invalid' is not a valid value, please specify `yes' or `no'", defaultShortOptDelimiter), &opts, "-vinvalid") +} + +func TestMarshalError(t *testing.T) { + var opts = struct { + Value marshalledError `short:"v"` + }{} + + p := NewParser(&opts, Default) + o := p.Command.Groups()[0].Options()[0] + + _, err := convertToString(o.value, o.tag) + + assertError(t, err, ErrMarshal, "Failed to marshal") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/multitag.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/multitag.go new file mode 100644 index 000000000..96bb1a31d --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/multitag.go @@ -0,0 +1,140 @@ +package flags + +import ( + "strconv" +) + +type multiTag struct { + value string + cache map[string][]string +} + +func newMultiTag(v string) multiTag { + return multiTag{ + value: v, + } +} + +func (x *multiTag) scan() (map[string][]string, error) { + v := x.value + + ret := make(map[string][]string) + + // This is mostly copied from reflect.StructTag.Get + for v != "" { + i := 0 + + // Skip whitespace + for i < len(v) && v[i] == ' ' { + i++ + } + + v = v[i:] + + if v == "" { + break + } + + // Scan to colon to find key + i = 0 + + for i < len(v) && v[i] != ' ' && v[i] != ':' && v[i] != '"' { + i++ + } + + if i >= len(v) { + return nil, newErrorf(ErrTag, "expected `:' after key name, but got end of tag (in `%v`)", x.value) + } + + if v[i] != ':' { + return nil, newErrorf(ErrTag, "expected `:' after key name, but got `%v' (in `%v`)", v[i], x.value) + } + + if i+1 >= len(v) { + return nil, newErrorf(ErrTag, "expected `\"' to start tag value at end of tag (in `%v`)", x.value) + } + + if v[i+1] != '"' { + return nil, newErrorf(ErrTag, "expected `\"' to start tag value, but got `%v' (in `%v`)", v[i+1], x.value) + } + + name := v[:i] + v = v[i+1:] + + // Scan quoted string to find value + i = 1 + + for i < len(v) && v[i] != '"' { + if v[i] == '\n' { + return nil, newErrorf(ErrTag, "unexpected newline in tag value `%v' (in `%v`)", name, x.value) + } + + if v[i] == '\\' { + i++ + } + i++ + } + + if i >= len(v) { + return nil, newErrorf(ErrTag, "expected end of tag value `\"' at end of tag (in `%v`)", x.value) + } + + val, err := strconv.Unquote(v[:i+1]) + + if err != nil { + return nil, newErrorf(ErrTag, "Malformed value of tag `%v:%v` => %v (in `%v`)", name, v[:i+1], err, x.value) + } + + v = v[i+1:] + + ret[name] = append(ret[name], val) + } + + return ret, nil +} + +func (x *multiTag) Parse() error { + vals, err := x.scan() + x.cache = vals + + return err +} + +func (x *multiTag) cached() map[string][]string { + if x.cache == nil { + cache, _ := x.scan() + + if cache == nil { + cache = make(map[string][]string) + } + + x.cache = cache + } + + return x.cache +} + +func (x *multiTag) Get(key string) string { + c := x.cached() + + if v, ok := c[key]; ok { + return v[len(v)-1] + } + + return "" +} + +func (x *multiTag) GetMany(key string) []string { + c := x.cached() + return c[key] +} + +func (x *multiTag) Set(key string, value string) { + c := x.cached() + c[key] = []string{value} +} + +func (x *multiTag) SetMany(key string, value []string) { + c := x.cached() + c[key] = value +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/option.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/option.go new file mode 100644 index 000000000..29e702c19 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/option.go @@ -0,0 +1,157 @@ +package flags + +import ( + "fmt" + "reflect" + "unicode/utf8" +) + +// Option flag information. Contains a description of the option, short and +// long name as well as a default value and whether an argument for this +// flag is optional. +type Option struct { + // The description of the option flag. This description is shown + // automatically in the built-in help. + Description string + + // The short name of the option (a single character). If not 0, the + // option flag can be 'activated' using -. Either ShortName + // or LongName needs to be non-empty. + ShortName rune + + // The long name of the option. If not "", the option flag can be + // activated using --. Either ShortName or LongName needs + // to be non-empty. + LongName string + + // The default value of the option. + Default []string + + // The optional environment default value key name. + EnvDefaultKey string + + // The optional delimiter string for EnvDefaultKey values. + EnvDefaultDelim string + + // If true, specifies that the argument to an option flag is optional. + // When no argument to the flag is specified on the command line, the + // value of Default will be set in the field this option represents. + // This is only valid for non-boolean options. + OptionalArgument bool + + // The optional value of the option. The optional value is used when + // the option flag is marked as having an OptionalArgument. This means + // that when the flag is specified, but no option argument is given, + // the value of the field this option represents will be set to + // OptionalValue. This is only valid for non-boolean options. + OptionalValue []string + + // If true, the option _must_ be specified on the command line. If the + // option is not specified, the parser will generate an ErrRequired type + // error. + Required bool + + // A name for the value of an option shown in the Help as --flag [ValueName] + ValueName string + + // A mask value to show in the help instead of the default value. This + // is useful for hiding sensitive information in the help, such as + // passwords. + DefaultMask string + + // The group which the option belongs to + group *Group + + // The struct field which the option represents. + field reflect.StructField + + // The struct field value which the option represents. + value reflect.Value + + // Determines if the option will be always quoted in the INI output + iniQuote bool + + tag multiTag + isSet bool +} + +// LongNameWithNamespace returns the option's long name with the group namespaces +// prepended by walking up the option's group tree. Namespaces and the long name +// itself are separated by the parser's namespace delimiter. If the long name is +// empty an empty string is returned. +func (option *Option) LongNameWithNamespace() string { + if len(option.LongName) == 0 { + return "" + } + + // fetch the namespace delimiter from the parser which is always at the + // end of the group hierarchy + namespaceDelimiter := "" + g := option.group + + for { + if p, ok := g.parent.(*Parser); ok { + namespaceDelimiter = p.NamespaceDelimiter + + break + } + + switch i := g.parent.(type) { + case *Command: + g = i.Group + case *Group: + g = i + } + } + + // concatenate long name with namespace + longName := option.LongName + g = option.group + + for g != nil { + if g.Namespace != "" { + longName = g.Namespace + namespaceDelimiter + longName + } + + switch i := g.parent.(type) { + case *Command: + g = i.Group + case *Group: + g = i + case *Parser: + g = nil + } + } + + return longName +} + +// String converts an option to a human friendly readable string describing the +// option. +func (option *Option) String() string { + var s string + var short string + + if option.ShortName != 0 { + data := make([]byte, utf8.RuneLen(option.ShortName)) + utf8.EncodeRune(data, option.ShortName) + short = string(data) + + if len(option.LongName) != 0 { + s = fmt.Sprintf("%s%s, %s%s", + string(defaultShortOptDelimiter), short, + defaultLongOptDelimiter, option.LongNameWithNamespace()) + } else { + s = fmt.Sprintf("%s%s", string(defaultShortOptDelimiter), short) + } + } else if len(option.LongName) != 0 { + s = fmt.Sprintf("%s%s", defaultLongOptDelimiter, option.LongNameWithNamespace()) + } + + return s +} + +// Value returns the option value as an interface{}. +func (option *Option) Value() interface{} { + return option.value.Interface() +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/option_private.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/option_private.go new file mode 100644 index 000000000..d36c84117 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/option_private.go @@ -0,0 +1,182 @@ +package flags + +import ( + "reflect" + "strings" + "syscall" +) + +// Set the value of an option to the specified value. An error will be returned +// if the specified value could not be converted to the corresponding option +// value type. +func (option *Option) set(value *string) error { + option.isSet = true + + if option.isFunc() { + return option.call(value) + } else if value != nil { + return convert(*value, option.value, option.tag) + } + + return convert("", option.value, option.tag) +} + +func (option *Option) canCli() bool { + return option.ShortName != 0 || len(option.LongName) != 0 +} + +func (option *Option) canArgument() bool { + if u := option.isUnmarshaler(); u != nil { + return true + } + + return !option.isBool() +} + +func (option *Option) emptyValue() reflect.Value { + tp := option.value.Type() + + if tp.Kind() == reflect.Map { + return reflect.MakeMap(tp) + } + + return reflect.Zero(tp) +} + +func (option *Option) empty() { + if !option.isFunc() { + option.value.Set(option.emptyValue()) + } +} + +func (option *Option) clearDefault() { + usedDefault := option.Default + if envKey := option.EnvDefaultKey; envKey != "" { + // os.Getenv() makes no distinction between undefined and + // empty values, so we use syscall.Getenv() + if value, ok := syscall.Getenv(envKey); ok { + if option.EnvDefaultDelim != "" { + usedDefault = strings.Split(value, + option.EnvDefaultDelim) + } else { + usedDefault = []string{value} + } + } + } + + if len(usedDefault) > 0 { + option.empty() + + for _, d := range usedDefault { + option.set(&d) + } + } else { + tp := option.value.Type() + + switch tp.Kind() { + case reflect.Map: + if option.value.IsNil() { + option.empty() + } + case reflect.Slice: + if option.value.IsNil() { + option.empty() + } + } + } +} + +func (option *Option) valueIsDefault() bool { + // Check if the value of the option corresponds to its + // default value + emptyval := option.emptyValue() + + checkvalptr := reflect.New(emptyval.Type()) + checkval := reflect.Indirect(checkvalptr) + + checkval.Set(emptyval) + + if len(option.Default) != 0 { + for _, v := range option.Default { + convert(v, checkval, option.tag) + } + } + + return reflect.DeepEqual(option.value.Interface(), checkval.Interface()) +} + +func (option *Option) isUnmarshaler() Unmarshaler { + v := option.value + + for { + if !v.CanInterface() { + break + } + + i := v.Interface() + + if u, ok := i.(Unmarshaler); ok { + return u + } + + if !v.CanAddr() { + break + } + + v = v.Addr() + } + + return nil +} + +func (option *Option) isBool() bool { + tp := option.value.Type() + + for { + switch tp.Kind() { + case reflect.Bool: + return true + case reflect.Slice: + return (tp.Elem().Kind() == reflect.Bool) + case reflect.Func: + return tp.NumIn() == 0 + case reflect.Ptr: + tp = tp.Elem() + default: + return false + } + } +} + +func (option *Option) isFunc() bool { + return option.value.Type().Kind() == reflect.Func +} + +func (option *Option) call(value *string) error { + var retval []reflect.Value + + if value == nil { + retval = option.value.Call(nil) + } else { + tp := option.value.Type().In(0) + + val := reflect.New(tp) + val = reflect.Indirect(val) + + if err := convert(*value, val, option.tag); err != nil { + return err + } + + retval = option.value.Call([]reflect.Value{val}) + } + + if len(retval) == 1 && retval[0].Type() == reflect.TypeOf((*error)(nil)).Elem() { + if retval[0].Interface() == nil { + return nil + } + + return retval[0].Interface().(error) + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/options_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/options_test.go new file mode 100644 index 000000000..b0fe9f456 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/options_test.go @@ -0,0 +1,45 @@ +package flags + +import ( + "testing" +) + +func TestPassDoubleDash(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + }{} + + p := NewParser(&opts, PassDoubleDash) + ret, err := p.ParseArgs([]string{"-v", "--", "-v", "-g"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + assertStringArray(t, ret, []string{"-v", "-g"}) +} + +func TestPassAfterNonOption(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + }{} + + p := NewParser(&opts, PassAfterNonOption) + ret, err := p.ParseArgs([]string{"-v", "arg", "-v", "-g"}) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + return + } + + if !opts.Value { + t.Errorf("Expected Value to be true") + } + + assertStringArray(t, ret, []string{"arg", "-v", "-g"}) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_other.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_other.go new file mode 100644 index 000000000..29ca4b606 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_other.go @@ -0,0 +1,67 @@ +// +build !windows + +package flags + +import ( + "strings" +) + +const ( + defaultShortOptDelimiter = '-' + defaultLongOptDelimiter = "--" + defaultNameArgDelimiter = '=' +) + +func argumentStartsOption(arg string) bool { + return len(arg) > 0 && arg[0] == '-' +} + +func argumentIsOption(arg string) bool { + if len(arg) > 1 && arg[0] == '-' && arg[1] != '-' { + return true + } + + if len(arg) > 2 && arg[0] == '-' && arg[1] == '-' && arg[2] != '-' { + return true + } + + return false +} + +// stripOptionPrefix returns the option without the prefix and whether or +// not the option is a long option or not. +func stripOptionPrefix(optname string) (prefix string, name string, islong bool) { + if strings.HasPrefix(optname, "--") { + return "--", optname[2:], true + } else if strings.HasPrefix(optname, "-") { + return "-", optname[1:], false + } + + return "", optname, false +} + +// splitOption attempts to split the passed option into a name and an argument. +// When there is no argument specified, nil will be returned for it. +func splitOption(prefix string, option string, islong bool) (string, string, *string) { + pos := strings.Index(option, "=") + + if (islong && pos >= 0) || (!islong && pos == 1) { + rest := option[pos+1:] + return option[:pos], "=", &rest + } + + return option, "", nil +} + +// addHelpGroup adds a new group that contains default help parameters. +func (c *Command) addHelpGroup(showHelp func() error) *Group { + var help struct { + ShowHelp func() error `short:"h" long:"help" description:"Show this help message"` + } + + help.ShowHelp = showHelp + ret, _ := c.AddGroup("Help Options", "", &help) + ret.isBuiltinHelp = true + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_windows.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_windows.go new file mode 100644 index 000000000..a51de9cb2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/optstyle_windows.go @@ -0,0 +1,106 @@ +package flags + +import ( + "strings" +) + +// Windows uses a front slash for both short and long options. Also it uses +// a colon for name/argument delimter. +const ( + defaultShortOptDelimiter = '/' + defaultLongOptDelimiter = "/" + defaultNameArgDelimiter = ':' +) + +func argumentStartsOption(arg string) bool { + return len(arg) > 0 && (arg[0] == '-' || arg[0] == '/') +} + +func argumentIsOption(arg string) bool { + // Windows-style options allow front slash for the option + // delimiter. + if len(arg) > 1 && arg[0] == '/' { + return true + } + + if len(arg) > 1 && arg[0] == '-' && arg[1] != '-' { + return true + } + + if len(arg) > 2 && arg[0] == '-' && arg[1] == '-' && arg[2] != '-' { + return true + } + + return false +} + +// stripOptionPrefix returns the option without the prefix and whether or +// not the option is a long option or not. +func stripOptionPrefix(optname string) (prefix string, name string, islong bool) { + // Determine if the argument is a long option or not. Windows + // typically supports both long and short options with a single + // front slash as the option delimiter, so handle this situation + // nicely. + possplit := 0 + + if strings.HasPrefix(optname, "--") { + possplit = 2 + islong = true + } else if strings.HasPrefix(optname, "-") { + possplit = 1 + islong = false + } else if strings.HasPrefix(optname, "/") { + possplit = 1 + islong = len(optname) > 2 + } + + return optname[:possplit], optname[possplit:], islong +} + +// splitOption attempts to split the passed option into a name and an argument. +// When there is no argument specified, nil will be returned for it. +func splitOption(prefix string, option string, islong bool) (string, string, *string) { + if len(option) == 0 { + return option, "", nil + } + + // Windows typically uses a colon for the option name and argument + // delimiter while POSIX typically uses an equals. Support both styles, + // but don't allow the two to be mixed. That is to say /foo:bar and + // --foo=bar are acceptable, but /foo=bar and --foo:bar are not. + var pos int + var sp string + + if prefix == "/" { + sp = ":" + pos = strings.Index(option, sp) + } else if len(prefix) > 0 { + sp = "=" + pos = strings.Index(option, sp) + } + + if (islong && pos >= 0) || (!islong && pos == 1) { + rest := option[pos+1:] + return option[:pos], sp, &rest + } + + return option, "", nil +} + +// addHelpGroup adds a new group that contains default help parameters. +func (c *Command) addHelpGroup(showHelp func() error) *Group { + // Windows CLI applications typically use /? for help, so make both + // that available as well as the POSIX style h and help. + var help struct { + ShowHelpWindows func() error `short:"?" description:"Show this help message"` + ShowHelpPosix func() error `short:"h" long:"help" description:"Show this help message"` + } + + help.ShowHelpWindows = showHelp + help.ShowHelpPosix = showHelp + + ret, _ := c.AddGroup("Help Options", "", &help) + ret.isBuiltinHelp = true + + return ret +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser.go new file mode 100644 index 000000000..6f45a0396 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser.go @@ -0,0 +1,286 @@ +// Copyright 2012 Jesse van den Kieboom. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flags + +import ( + "os" + "path" +) + +// A Parser provides command line option parsing. It can contain several +// option groups each with their own set of options. +type Parser struct { + // Embedded, see Command for more information + *Command + + // A usage string to be displayed in the help message. + Usage string + + // Option flags changing the behavior of the parser. + Options Options + + // NamespaceDelimiter separates group namespaces and option long names + NamespaceDelimiter string + + // UnknownOptionsHandler is a function which gets called when the parser + // encounters an unknown option. The function receives the unknown option + // name, a SplitArgument which specifies its value if set with an argument + // separator, and the remaining command line arguments. + // It should return a new list of remaining arguments to continue parsing, + // or an error to indicate a parse failure. + UnknownOptionHandler func(option string, arg SplitArgument, args []string) ([]string, error) + + internalError error +} + +// SplitArgument represents the argument value of an option that was passed using +// an argument separator. +type SplitArgument interface { + // String returns the option's value as a string, and a boolean indicating + // if the option was present. + Value() (string, bool) +} + +type strArgument struct { + value *string +} + +func (s strArgument) Value() (string, bool) { + if s.value == nil { + return "", false + } + + return *s.value, true +} + +// Options provides parser options that change the behavior of the option +// parser. +type Options uint + +const ( + // None indicates no options. + None Options = 0 + + // HelpFlag adds a default Help Options group to the parser containing + // -h and --help options. When either -h or --help is specified on the + // command line, the parser will return the special error of type + // ErrHelp. When PrintErrors is also specified, then the help message + // will also be automatically printed to os.Stderr. + HelpFlag = 1 << iota + + // PassDoubleDash passes all arguments after a double dash, --, as + // remaining command line arguments (i.e. they will not be parsed for + // flags). + PassDoubleDash + + // IgnoreUnknown ignores any unknown options and passes them as + // remaining command line arguments instead of generating an error. + IgnoreUnknown + + // PrintErrors prints any errors which occurred during parsing to + // os.Stderr. + PrintErrors + + // PassAfterNonOption passes all arguments after the first non option + // as remaining command line arguments. This is equivalent to strict + // POSIX processing. + PassAfterNonOption + + // Default is a convenient default set of options which should cover + // most of the uses of the flags package. + Default = HelpFlag | PrintErrors | PassDoubleDash +) + +// Parse is a convenience function to parse command line options with default +// settings. The provided data is a pointer to a struct representing the +// default option group (named "Application Options"). For more control, use +// flags.NewParser. +func Parse(data interface{}) ([]string, error) { + return NewParser(data, Default).Parse() +} + +// ParseArgs is a convenience function to parse command line options with default +// settings. The provided data is a pointer to a struct representing the +// default option group (named "Application Options"). The args argument is +// the list of command line arguments to parse. If you just want to parse the +// default program command line arguments (i.e. os.Args), then use flags.Parse +// instead. For more control, use flags.NewParser. +func ParseArgs(data interface{}, args []string) ([]string, error) { + return NewParser(data, Default).ParseArgs(args) +} + +// NewParser creates a new parser. It uses os.Args[0] as the application +// name and then calls Parser.NewNamedParser (see Parser.NewNamedParser for +// more details). The provided data is a pointer to a struct representing the +// default option group (named "Application Options"), or nil if the default +// group should not be added. The options parameter specifies a set of options +// for the parser. +func NewParser(data interface{}, options Options) *Parser { + p := NewNamedParser(path.Base(os.Args[0]), options) + + if data != nil { + g, err := p.AddGroup("Application Options", "", data) + + if err == nil { + g.parent = p + } + + p.internalError = err + } + + return p +} + +// NewNamedParser creates a new parser. The appname is used to display the +// executable name in the built-in help message. Option groups and commands can +// be added to this parser by using AddGroup and AddCommand. +func NewNamedParser(appname string, options Options) *Parser { + p := &Parser{ + Command: newCommand(appname, "", "", nil), + Options: options, + NamespaceDelimiter: ".", + } + + p.Command.parent = p + + return p +} + +// Parse parses the command line arguments from os.Args using Parser.ParseArgs. +// For more detailed information see ParseArgs. +func (p *Parser) Parse() ([]string, error) { + return p.ParseArgs(os.Args[1:]) +} + +// ParseArgs parses the command line arguments according to the option groups that +// were added to the parser. On successful parsing of the arguments, the +// remaining, non-option, arguments (if any) are returned. The returned error +// indicates a parsing error and can be used with PrintError to display +// contextual information on where the error occurred exactly. +// +// When the common help group has been added (AddHelp) and either -h or --help +// was specified in the command line arguments, a help message will be +// automatically printed. Furthermore, the special error type ErrHelp is returned. +// It is up to the caller to exit the program if so desired. +func (p *Parser) ParseArgs(args []string) ([]string, error) { + if p.internalError != nil { + return nil, p.internalError + } + + p.clearIsSet() + + // Add built-in help group to all commands if necessary + if (p.Options & HelpFlag) != None { + p.addHelpGroups(p.showBuiltinHelp) + } + + compval := os.Getenv("GO_FLAGS_COMPLETION") + + if len(compval) != 0 { + comp := &completion{parser: p} + + if compval == "verbose" { + comp.ShowDescriptions = true + } + + comp.execute(args) + + return nil, nil + } + + s := &parseState{ + args: args, + retargs: make([]string, 0, len(args)), + } + + p.fillParseState(s) + + for !s.eof() { + arg := s.pop() + + // When PassDoubleDash is set and we encounter a --, then + // simply append all the rest as arguments and break out + if (p.Options&PassDoubleDash) != None && arg == "--" { + s.addArgs(s.args...) + break + } + + if !argumentIsOption(arg) { + // Note: this also sets s.err, so we can just check for + // nil here and use s.err later + if p.parseNonOption(s) != nil { + break + } + + continue + } + + var err error + + prefix, optname, islong := stripOptionPrefix(arg) + optname, _, argument := splitOption(prefix, optname, islong) + + if islong { + err = p.parseLong(s, optname, argument) + } else { + err = p.parseShort(s, optname, argument) + } + + if err != nil { + ignoreUnknown := (p.Options & IgnoreUnknown) != None + parseErr := wrapError(err) + + if parseErr.Type != ErrUnknownFlag || (!ignoreUnknown && p.UnknownOptionHandler == nil) { + s.err = parseErr + break + } + + if ignoreUnknown { + s.addArgs(arg) + } else if p.UnknownOptionHandler != nil { + modifiedArgs, err := p.UnknownOptionHandler(optname, strArgument{argument}, s.args) + + if err != nil { + s.err = err + break + } + + s.args = modifiedArgs + } + } + } + + if s.err == nil { + p.eachCommand(func(c *Command) { + c.eachGroup(func(g *Group) { + for _, option := range g.options { + if option.isSet { + continue + } + + option.clearDefault() + } + }) + }, true) + + s.checkRequired(p) + } + + var reterr error + + if s.err != nil { + reterr = s.err + } else if len(s.command.commands) != 0 && !s.command.SubcommandsOptional { + reterr = s.estimateCommand() + } else if cmd, ok := s.command.data.(Commander); ok { + reterr = cmd.Execute(s.retargs) + } + + if reterr != nil { + return append([]string{s.arg}, s.args...), p.printError(reterr) + } + + return s.retargs, nil +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_private.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_private.go new file mode 100644 index 000000000..76be4a75f --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_private.go @@ -0,0 +1,340 @@ +package flags + +import ( + "bytes" + "fmt" + "os" + "sort" + "strings" + "unicode/utf8" +) + +type parseState struct { + arg string + args []string + retargs []string + positional []*Arg + err error + + command *Command + lookup lookup +} + +func (p *parseState) eof() bool { + return len(p.args) == 0 +} + +func (p *parseState) pop() string { + if p.eof() { + return "" + } + + p.arg = p.args[0] + p.args = p.args[1:] + + return p.arg +} + +func (p *parseState) peek() string { + if p.eof() { + return "" + } + + return p.args[0] +} + +func (p *parseState) checkRequired(parser *Parser) error { + c := parser.Command + + var required []*Option + + for c != nil { + c.eachGroup(func(g *Group) { + for _, option := range g.options { + if !option.isSet && option.Required { + required = append(required, option) + } + } + }) + + c = c.Active + } + + if len(required) == 0 { + if len(p.positional) > 0 && p.command.ArgsRequired { + var reqnames []string + + for _, arg := range p.positional { + if arg.isRemaining() { + break + } + + reqnames = append(reqnames, "`"+arg.Name+"`") + } + + if len(reqnames) == 0 { + return nil + } + + var msg string + + if len(reqnames) == 1 { + msg = fmt.Sprintf("the required argument %s was not provided", reqnames[0]) + } else { + msg = fmt.Sprintf("the required arguments %s and %s were not provided", + strings.Join(reqnames[:len(reqnames)-1], ", "), reqnames[len(reqnames)-1]) + } + + p.err = newError(ErrRequired, msg) + return p.err + } + + return nil + } + + names := make([]string, 0, len(required)) + + for _, k := range required { + names = append(names, "`"+k.String()+"'") + } + + sort.Strings(names) + + var msg string + + if len(names) == 1 { + msg = fmt.Sprintf("the required flag %s was not specified", names[0]) + } else { + msg = fmt.Sprintf("the required flags %s and %s were not specified", + strings.Join(names[:len(names)-1], ", "), names[len(names)-1]) + } + + p.err = newError(ErrRequired, msg) + return p.err +} + +func (p *parseState) estimateCommand() error { + commands := p.command.sortedCommands() + cmdnames := make([]string, len(commands)) + + for i, v := range commands { + cmdnames[i] = v.Name + } + + var msg string + var errtype ErrorType + + if len(p.retargs) != 0 { + c, l := closestChoice(p.retargs[0], cmdnames) + msg = fmt.Sprintf("Unknown command `%s'", p.retargs[0]) + errtype = ErrUnknownCommand + + if float32(l)/float32(len(c)) < 0.5 { + msg = fmt.Sprintf("%s, did you mean `%s'?", msg, c) + } else if len(cmdnames) == 1 { + msg = fmt.Sprintf("%s. You should use the %s command", + msg, + cmdnames[0]) + } else { + msg = fmt.Sprintf("%s. Please specify one command of: %s or %s", + msg, + strings.Join(cmdnames[:len(cmdnames)-1], ", "), + cmdnames[len(cmdnames)-1]) + } + } else { + errtype = ErrCommandRequired + + if len(cmdnames) == 1 { + msg = fmt.Sprintf("Please specify the %s command", cmdnames[0]) + } else { + msg = fmt.Sprintf("Please specify one command of: %s or %s", + strings.Join(cmdnames[:len(cmdnames)-1], ", "), + cmdnames[len(cmdnames)-1]) + } + } + + return newError(errtype, msg) +} + +func (p *Parser) parseOption(s *parseState, name string, option *Option, canarg bool, argument *string) (err error) { + if !option.canArgument() { + if argument != nil { + return newErrorf(ErrNoArgumentForBool, "bool flag `%s' cannot have an argument", option) + } + + err = option.set(nil) + } else if argument != nil || (canarg && !s.eof()) { + var arg string + + if argument != nil { + arg = *argument + } else { + arg = s.pop() + + if argumentIsOption(arg) { + return newErrorf(ErrExpectedArgument, "expected argument for flag `%s', but got option `%s'", option, arg) + } else if p.Options&PassDoubleDash != 0 && arg == "--" { + return newErrorf(ErrExpectedArgument, "expected argument for flag `%s', but got double dash `--'", option) + } + } + + if option.tag.Get("unquote") != "false" { + arg, err = unquoteIfPossible(arg) + } + + if err == nil { + err = option.set(&arg) + } + } else if option.OptionalArgument { + option.empty() + + for _, v := range option.OptionalValue { + err = option.set(&v) + + if err != nil { + break + } + } + } else { + err = newErrorf(ErrExpectedArgument, "expected argument for flag `%s'", option) + } + + if err != nil { + if _, ok := err.(*Error); !ok { + err = newErrorf(ErrMarshal, "invalid argument for flag `%s' (expected %s): %s", + option, + option.value.Type(), + err.Error()) + } + } + + return err +} + +func (p *Parser) parseLong(s *parseState, name string, argument *string) error { + if option := s.lookup.longNames[name]; option != nil { + // Only long options that are required can consume an argument + // from the argument list + canarg := !option.OptionalArgument + + return p.parseOption(s, name, option, canarg, argument) + } + + return newErrorf(ErrUnknownFlag, "unknown flag `%s'", name) +} + +func (p *Parser) splitShortConcatArg(s *parseState, optname string) (string, *string) { + c, n := utf8.DecodeRuneInString(optname) + + if n == len(optname) { + return optname, nil + } + + first := string(c) + + if option := s.lookup.shortNames[first]; option != nil && option.canArgument() { + arg := optname[n:] + return first, &arg + } + + return optname, nil +} + +func (p *Parser) parseShort(s *parseState, optname string, argument *string) error { + if argument == nil { + optname, argument = p.splitShortConcatArg(s, optname) + } + + for i, c := range optname { + shortname := string(c) + + if option := s.lookup.shortNames[shortname]; option != nil { + // Only the last short argument can consume an argument from + // the arguments list, and only if it's non optional + canarg := (i+utf8.RuneLen(c) == len(optname)) && !option.OptionalArgument + + if err := p.parseOption(s, shortname, option, canarg, argument); err != nil { + return err + } + } else { + return newErrorf(ErrUnknownFlag, "unknown flag `%s'", shortname) + } + + // Only the first option can have a concatted argument, so just + // clear argument here + argument = nil + } + + return nil +} + +func (p *parseState) addArgs(args ...string) error { + for len(p.positional) > 0 && len(args) > 0 { + arg := p.positional[0] + + if err := convert(args[0], arg.value, arg.tag); err != nil { + return err + } + + if !arg.isRemaining() { + p.positional = p.positional[1:] + } + + args = args[1:] + } + + p.retargs = append(p.retargs, args...) + return nil +} + +func (p *Parser) parseNonOption(s *parseState) error { + if len(s.positional) > 0 { + return s.addArgs(s.arg) + } + + if cmd := s.lookup.commands[s.arg]; cmd != nil { + s.command.Active = cmd + cmd.fillParseState(s) + } else if (p.Options & PassAfterNonOption) != None { + // If PassAfterNonOption is set then all remaining arguments + // are considered positional + if err := s.addArgs(s.arg); err != nil { + return err + } + + if err := s.addArgs(s.args...); err != nil { + return err + } + + s.args = []string{} + } else { + return s.addArgs(s.arg) + } + + return nil +} + +func (p *Parser) showBuiltinHelp() error { + var b bytes.Buffer + + p.WriteHelp(&b) + return newError(ErrHelp, b.String()) +} + +func (p *Parser) printError(err error) error { + if err != nil && (p.Options&PrintErrors) != None { + fmt.Fprintln(os.Stderr, err) + } + + return err +} + +func (p *Parser) clearIsSet() { + p.eachCommand(func(c *Command) { + c.eachGroup(func(g *Group) { + for _, option := range g.options { + option.isSet = false + } + }) + }, true) +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_test.go new file mode 100644 index 000000000..579287262 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/parser_test.go @@ -0,0 +1,431 @@ +package flags + +import ( + "fmt" + "os" + "reflect" + "strconv" + "strings" + "testing" + "time" +) + +type defaultOptions struct { + Int int `long:"i"` + IntDefault int `long:"id" default:"1"` + + String string `long:"str"` + StringDefault string `long:"strd" default:"abc"` + StringNotUnquoted string `long:"strnot" unquote:"false"` + + Time time.Duration `long:"t"` + TimeDefault time.Duration `long:"td" default:"1m"` + + Map map[string]int `long:"m"` + MapDefault map[string]int `long:"md" default:"a:1"` + + Slice []int `long:"s"` + SliceDefault []int `long:"sd" default:"1" default:"2"` +} + +func TestDefaults(t *testing.T) { + var tests = []struct { + msg string + args []string + expected defaultOptions + }{ + { + msg: "no arguments, expecting default values", + args: []string{}, + expected: defaultOptions{ + Int: 0, + IntDefault: 1, + + String: "", + StringDefault: "abc", + + Time: 0, + TimeDefault: time.Minute, + + Map: map[string]int{}, + MapDefault: map[string]int{"a": 1}, + + Slice: []int{}, + SliceDefault: []int{1, 2}, + }, + }, + { + msg: "non-zero value arguments, expecting overwritten arguments", + args: []string{"--i=3", "--id=3", "--str=def", "--strd=def", "--t=3ms", "--td=3ms", "--m=c:3", "--md=c:3", "--s=3", "--sd=3"}, + expected: defaultOptions{ + Int: 3, + IntDefault: 3, + + String: "def", + StringDefault: "def", + + Time: 3 * time.Millisecond, + TimeDefault: 3 * time.Millisecond, + + Map: map[string]int{"c": 3}, + MapDefault: map[string]int{"c": 3}, + + Slice: []int{3}, + SliceDefault: []int{3}, + }, + }, + { + msg: "zero value arguments, expecting overwritten arguments", + args: []string{"--i=0", "--id=0", "--str", "", "--strd=\"\"", "--t=0ms", "--td=0s", "--m=:0", "--md=:0", "--s=0", "--sd=0"}, + expected: defaultOptions{ + Int: 0, + IntDefault: 0, + + String: "", + StringDefault: "", + + Time: 0, + TimeDefault: 0, + + Map: map[string]int{"": 0}, + MapDefault: map[string]int{"": 0}, + + Slice: []int{0}, + SliceDefault: []int{0}, + }, + }, + } + + for _, test := range tests { + var opts defaultOptions + + _, err := ParseArgs(&opts, test.args) + if err != nil { + t.Fatalf("%s:\nUnexpected error: %v", test.msg, err) + } + + if opts.Slice == nil { + opts.Slice = []int{} + } + + if !reflect.DeepEqual(opts, test.expected) { + t.Errorf("%s:\nUnexpected options with arguments %+v\nexpected\n%+v\nbut got\n%+v\n", test.msg, test.args, test.expected, opts) + } + } +} + +func TestUnquoting(t *testing.T) { + var tests = []struct { + arg string + err error + value string + }{ + { + arg: "\"abc", + err: strconv.ErrSyntax, + value: "", + }, + { + arg: "\"\"abc\"", + err: strconv.ErrSyntax, + value: "", + }, + { + arg: "\"abc\"", + err: nil, + value: "abc", + }, + { + arg: "\"\\\"abc\\\"\"", + err: nil, + value: "\"abc\"", + }, + { + arg: "\"\\\"abc\"", + err: nil, + value: "\"abc", + }, + } + + for _, test := range tests { + var opts defaultOptions + + for _, delimiter := range []bool{false, true} { + p := NewParser(&opts, None) + + var err error + if delimiter { + _, err = p.ParseArgs([]string{"--str=" + test.arg, "--strnot=" + test.arg}) + } else { + _, err = p.ParseArgs([]string{"--str", test.arg, "--strnot", test.arg}) + } + + if test.err == nil { + if err != nil { + t.Fatalf("Expected no error but got: %v", err) + } + + if test.value != opts.String { + t.Fatalf("Expected String to be %q but got %q", test.value, opts.String) + } + if q := strconv.Quote(test.value); q != opts.StringNotUnquoted { + t.Fatalf("Expected StringDefault to be %q but got %q", q, opts.StringNotUnquoted) + } + } else { + if err == nil { + t.Fatalf("Expected error") + } else if e, ok := err.(*Error); ok { + if strings.HasPrefix(e.Message, test.err.Error()) { + t.Fatalf("Expected error message to end with %q but got %v", test.err.Error(), e.Message) + } + } + } + } + } +} + +// envRestorer keeps a copy of a set of env variables and can restore the env from them +type envRestorer struct { + env map[string]string +} + +func (r *envRestorer) Restore() { + os.Clearenv() + for k, v := range r.env { + os.Setenv(k, v) + } +} + +// EnvSnapshot returns a snapshot of the currently set env variables +func EnvSnapshot() *envRestorer { + r := envRestorer{make(map[string]string)} + for _, kv := range os.Environ() { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + panic("got a weird env variable: " + kv) + } + r.env[parts[0]] = parts[1] + } + return &r +} + +type envDefaultOptions struct { + Int int `long:"i" default:"1" env:"TEST_I"` + Time time.Duration `long:"t" default:"1m" env:"TEST_T"` + Map map[string]int `long:"m" default:"a:1" env:"TEST_M" env-delim:";"` + Slice []int `long:"s" default:"1" default:"2" env:"TEST_S" env-delim:","` +} + +func TestEnvDefaults(t *testing.T) { + var tests = []struct { + msg string + args []string + expected envDefaultOptions + env map[string]string + }{ + { + msg: "no arguments, no env, expecting default values", + args: []string{}, + expected: envDefaultOptions{ + Int: 1, + Time: time.Minute, + Map: map[string]int{"a": 1}, + Slice: []int{1, 2}, + }, + }, + { + msg: "no arguments, env defaults, expecting env default values", + args: []string{}, + expected: envDefaultOptions{ + Int: 2, + Time: 2 * time.Minute, + Map: map[string]int{"a": 2, "b": 3}, + Slice: []int{4, 5, 6}, + }, + env: map[string]string{ + "TEST_I": "2", + "TEST_T": "2m", + "TEST_M": "a:2;b:3", + "TEST_S": "4,5,6", + }, + }, + { + msg: "non-zero value arguments, expecting overwritten arguments", + args: []string{"--i=3", "--t=3ms", "--m=c:3", "--s=3"}, + expected: envDefaultOptions{ + Int: 3, + Time: 3 * time.Millisecond, + Map: map[string]int{"c": 3}, + Slice: []int{3}, + }, + env: map[string]string{ + "TEST_I": "2", + "TEST_T": "2m", + "TEST_M": "a:2;b:3", + "TEST_S": "4,5,6", + }, + }, + { + msg: "zero value arguments, expecting overwritten arguments", + args: []string{"--i=0", "--t=0ms", "--m=:0", "--s=0"}, + expected: envDefaultOptions{ + Int: 0, + Time: 0, + Map: map[string]int{"": 0}, + Slice: []int{0}, + }, + env: map[string]string{ + "TEST_I": "2", + "TEST_T": "2m", + "TEST_M": "a:2;b:3", + "TEST_S": "4,5,6", + }, + }, + } + + oldEnv := EnvSnapshot() + defer oldEnv.Restore() + + for _, test := range tests { + var opts envDefaultOptions + oldEnv.Restore() + for envKey, envValue := range test.env { + os.Setenv(envKey, envValue) + } + _, err := ParseArgs(&opts, test.args) + if err != nil { + t.Fatalf("%s:\nUnexpected error: %v", test.msg, err) + } + + if opts.Slice == nil { + opts.Slice = []int{} + } + + if !reflect.DeepEqual(opts, test.expected) { + t.Errorf("%s:\nUnexpected options with arguments %+v\nexpected\n%+v\nbut got\n%+v\n", test.msg, test.args, test.expected, opts) + } + } +} + +func TestOptionAsArgument(t *testing.T) { + var tests = []struct { + args []string + expectError bool + errType ErrorType + errMsg string + rest []string + }{ + { + // short option must not be accepted as argument + args: []string{"--string-slice", "foobar", "--string-slice", "-o"}, + expectError: true, + errType: ErrExpectedArgument, + errMsg: "expected argument for flag `--string-slice', but got option `-o'", + }, + { + // long option must not be accepted as argument + args: []string{"--string-slice", "foobar", "--string-slice", "--other-option"}, + expectError: true, + errType: ErrExpectedArgument, + errMsg: "expected argument for flag `--string-slice', but got option `--other-option'", + }, + { + // long option must not be accepted as argument + args: []string{"--string-slice", "--"}, + expectError: true, + errType: ErrExpectedArgument, + errMsg: "expected argument for flag `--string-slice', but got double dash `--'", + }, + { + // quoted and appended option should be accepted as argument (even if it looks like an option) + args: []string{"--string-slice", "foobar", "--string-slice=\"--other-option\""}, + }, + { + // Accept any single character arguments including '-' + args: []string{"--string-slice", "-"}, + }, + { + args: []string{"-o", "-", "-"}, + rest: []string{"-", "-"}, + }, + } + var opts struct { + StringSlice []string `long:"string-slice"` + OtherOption bool `long:"other-option" short:"o"` + } + + for _, test := range tests { + if test.expectError { + assertParseFail(t, test.errType, test.errMsg, &opts, test.args...) + } else { + args := assertParseSuccess(t, &opts, test.args...) + + assertStringArray(t, args, test.rest) + } + } +} + +func TestUnknownFlagHandler(t *testing.T) { + + var opts struct { + Flag1 string `long:"flag1"` + Flag2 string `long:"flag2"` + } + + p := NewParser(&opts, None) + + var unknownFlag1 string + var unknownFlag2 bool + var unknownFlag3 string + + // Set up a callback to intercept unknown options during parsing + p.UnknownOptionHandler = func(option string, arg SplitArgument, args []string) ([]string, error) { + if option == "unknownFlag1" { + if argValue, ok := arg.Value(); ok { + unknownFlag1 = argValue + return args, nil + } + // consume a value from remaining args list + unknownFlag1 = args[0] + return args[1:], nil + } else if option == "unknownFlag2" { + // treat this one as a bool switch, don't consume any args + unknownFlag2 = true + return args, nil + } else if option == "unknownFlag3" { + if argValue, ok := arg.Value(); ok { + unknownFlag3 = argValue + return args, nil + } + // consume a value from remaining args list + unknownFlag3 = args[0] + return args[1:], nil + } + + return args, fmt.Errorf("Unknown flag: %v", option) + } + + // Parse args containing some unknown flags, verify that + // our callback can handle all of them + _, err := p.ParseArgs([]string{"--flag1=stuff", "--unknownFlag1", "blah", "--unknownFlag2", "--unknownFlag3=baz", "--flag2=foo"}) + + if err != nil { + assertErrorf(t, "Parser returned unexpected error %v", err) + } + + assertString(t, opts.Flag1, "stuff") + assertString(t, opts.Flag2, "foo") + assertString(t, unknownFlag1, "blah") + assertString(t, unknownFlag3, "baz") + + if !unknownFlag2 { + assertErrorf(t, "Flag should have been set by unknown handler, but had value: %v", unknownFlag2) + } + + // Parse args with unknown flags that callback doesn't handle, verify it returns error + _, err = p.ParseArgs([]string{"--flag1=stuff", "--unknownFlagX", "blah", "--flag2=foo"}) + + if err == nil { + assertErrorf(t, "Parser should have returned error, but returned nil") + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/pointer_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/pointer_test.go new file mode 100644 index 000000000..e17445f69 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/pointer_test.go @@ -0,0 +1,81 @@ +package flags + +import ( + "testing" +) + +func TestPointerBool(t *testing.T) { + var opts = struct { + Value *bool `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v") + + assertStringArray(t, ret, []string{}) + + if !*opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestPointerString(t *testing.T) { + var opts = struct { + Value *string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v", "value") + + assertStringArray(t, ret, []string{}) + assertString(t, *opts.Value, "value") +} + +func TestPointerSlice(t *testing.T) { + var opts = struct { + Value *[]string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v", "value1", "-v", "value2") + + assertStringArray(t, ret, []string{}) + assertStringArray(t, *opts.Value, []string{"value1", "value2"}) +} + +func TestPointerMap(t *testing.T) { + var opts = struct { + Value *map[string]int `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v", "k1:2", "-v", "k2:-5") + + assertStringArray(t, ret, []string{}) + + if v, ok := (*opts.Value)["k1"]; !ok { + t.Errorf("Expected key \"k1\" to exist") + } else if v != 2 { + t.Errorf("Expected \"k1\" to be 2, but got %#v", v) + } + + if v, ok := (*opts.Value)["k2"]; !ok { + t.Errorf("Expected key \"k2\" to exist") + } else if v != -5 { + t.Errorf("Expected \"k2\" to be -5, but got %#v", v) + } +} + +type PointerGroup struct { + Value bool `short:"v"` +} + +func TestPointerGroup(t *testing.T) { + var opts = struct { + Group *PointerGroup `group:"Group Options"` + }{} + + ret := assertParseSuccess(t, &opts, "-v") + + assertStringArray(t, ret, []string{}) + + if !opts.Group.Value { + t.Errorf("Expected Group.Value to be true") + } +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/short_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/short_test.go new file mode 100644 index 000000000..95712c162 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/short_test.go @@ -0,0 +1,194 @@ +package flags + +import ( + "fmt" + "testing" +) + +func TestShort(t *testing.T) { + var opts = struct { + Value bool `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v") + + assertStringArray(t, ret, []string{}) + + if !opts.Value { + t.Errorf("Expected Value to be true") + } +} + +func TestShortTooLong(t *testing.T) { + var opts = struct { + Value bool `short:"vv"` + }{} + + assertParseFail(t, ErrShortNameTooLong, "short names can only be 1 character long, not `vv'", &opts) +} + +func TestShortRequired(t *testing.T) { + var opts = struct { + Value bool `short:"v" required:"true"` + }{} + + assertParseFail(t, ErrRequired, fmt.Sprintf("the required flag `%cv' was not specified", defaultShortOptDelimiter), &opts) +} + +func TestShortMultiConcat(t *testing.T) { + var opts = struct { + V bool `short:"v"` + O bool `short:"o"` + F bool `short:"f"` + }{} + + ret := assertParseSuccess(t, &opts, "-vo", "-f") + + assertStringArray(t, ret, []string{}) + + if !opts.V { + t.Errorf("Expected V to be true") + } + + if !opts.O { + t.Errorf("Expected O to be true") + } + + if !opts.F { + t.Errorf("Expected F to be true") + } +} + +func TestShortMultiRequiredConcat(t *testing.T) { + var opts = struct { + V bool `short:"v" required:"true"` + O bool `short:"o" required:"true"` + F bool `short:"f" required:"true"` + }{} + + ret := assertParseSuccess(t, &opts, "-vo", "-f") + + assertStringArray(t, ret, []string{}) + + if !opts.V { + t.Errorf("Expected V to be true") + } + + if !opts.O { + t.Errorf("Expected O to be true") + } + + if !opts.F { + t.Errorf("Expected F to be true") + } +} + +func TestShortMultiSlice(t *testing.T) { + var opts = struct { + Values []bool `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v", "-v") + + assertStringArray(t, ret, []string{}) + assertBoolArray(t, opts.Values, []bool{true, true}) +} + +func TestShortMultiSliceConcat(t *testing.T) { + var opts = struct { + Values []bool `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-vvv") + + assertStringArray(t, ret, []string{}) + assertBoolArray(t, opts.Values, []bool{true, true, true}) +} + +func TestShortWithEqualArg(t *testing.T) { + var opts = struct { + Value string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v=value") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestShortWithArg(t *testing.T) { + var opts = struct { + Value string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-vvalue") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestShortArg(t *testing.T) { + var opts = struct { + Value string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-v", "value") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "value") +} + +func TestShortMultiWithEqualArg(t *testing.T) { + var opts = struct { + F []bool `short:"f"` + Value string `short:"v"` + }{} + + assertParseFail(t, ErrExpectedArgument, fmt.Sprintf("expected argument for flag `%cv'", defaultShortOptDelimiter), &opts, "-ffv=value") +} + +func TestShortMultiArg(t *testing.T) { + var opts = struct { + F []bool `short:"f"` + Value string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-ffv", "value") + + assertStringArray(t, ret, []string{}) + assertBoolArray(t, opts.F, []bool{true, true}) + assertString(t, opts.Value, "value") +} + +func TestShortMultiArgConcatFail(t *testing.T) { + var opts = struct { + F []bool `short:"f"` + Value string `short:"v"` + }{} + + assertParseFail(t, ErrExpectedArgument, fmt.Sprintf("expected argument for flag `%cv'", defaultShortOptDelimiter), &opts, "-ffvvalue") +} + +func TestShortMultiArgConcat(t *testing.T) { + var opts = struct { + F []bool `short:"f"` + Value string `short:"v"` + }{} + + ret := assertParseSuccess(t, &opts, "-vff") + + assertStringArray(t, ret, []string{}) + assertString(t, opts.Value, "ff") +} + +func TestShortOptional(t *testing.T) { + var opts = struct { + F []bool `short:"f"` + Value string `short:"v" optional:"yes" optional-value:"value"` + }{} + + ret := assertParseSuccess(t, &opts, "-fv", "f") + + assertStringArray(t, ret, []string{"f"}) + assertString(t, opts.Value, "value") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/tag_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/tag_test.go new file mode 100644 index 000000000..9daa7401b --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/tag_test.go @@ -0,0 +1,38 @@ +package flags + +import ( + "testing" +) + +func TestTagMissingColon(t *testing.T) { + var opts = struct { + Value bool `short` + }{} + + assertParseFail(t, ErrTag, "expected `:' after key name, but got end of tag (in `short`)", &opts, "") +} + +func TestTagMissingValue(t *testing.T) { + var opts = struct { + Value bool `short:` + }{} + + assertParseFail(t, ErrTag, "expected `\"' to start tag value at end of tag (in `short:`)", &opts, "") +} + +func TestTagMissingQuote(t *testing.T) { + var opts = struct { + Value bool `short:"v` + }{} + + assertParseFail(t, ErrTag, "expected end of tag value `\"' at end of tag (in `short:\"v`)", &opts, "") +} + +func TestTagNewline(t *testing.T) { + var opts = struct { + Value bool `long:"verbose" description:"verbose +something"` + }{} + + assertParseFail(t, ErrTag, "unexpected newline in tag value `description' (in `long:\"verbose\" description:\"verbose\nsomething\"`)", &opts, "") +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize.go new file mode 100644 index 000000000..df97e7e82 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize.go @@ -0,0 +1,28 @@ +// +build !windows,!plan9,!solaris + +package flags + +import ( + "syscall" + "unsafe" +) + +type winsize struct { + row, col uint16 + xpixel, ypixel uint16 +} + +func getTerminalColumns() int { + ws := winsize{} + + if tIOCGWINSZ != 0 { + syscall.Syscall(syscall.SYS_IOCTL, + uintptr(0), + uintptr(tIOCGWINSZ), + uintptr(unsafe.Pointer(&ws))) + + return int(ws.col) + } + + return 80 +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_linux.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_linux.go new file mode 100644 index 000000000..e3975e283 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_linux.go @@ -0,0 +1,7 @@ +// +build linux + +package flags + +const ( + tIOCGWINSZ = 0x5413 +) diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_nosysioctl.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_nosysioctl.go new file mode 100644 index 000000000..2a9bbe005 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_nosysioctl.go @@ -0,0 +1,7 @@ +// +build windows plan9 solaris + +package flags + +func getTerminalColumns() int { + return 80 +} diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_other.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_other.go new file mode 100644 index 000000000..308215155 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_other.go @@ -0,0 +1,7 @@ +// +build !darwin,!freebsd,!netbsd,!openbsd,!linux + +package flags + +const ( + tIOCGWINSZ = 0 +) diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_unix.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_unix.go new file mode 100644 index 000000000..fcc118601 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/termsize_unix.go @@ -0,0 +1,7 @@ +// +build darwin freebsd netbsd openbsd + +package flags + +const ( + tIOCGWINSZ = 0x40087468 +) diff --git a/Godeps/_workspace/src/github.com/jessevdk/go-flags/unknown_test.go b/Godeps/_workspace/src/github.com/jessevdk/go-flags/unknown_test.go new file mode 100644 index 000000000..858be4588 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jessevdk/go-flags/unknown_test.go @@ -0,0 +1,66 @@ +package flags + +import ( + "testing" +) + +func TestUnknownFlags(t *testing.T) { + var opts = struct { + Verbose []bool `short:"v" long:"verbose" description:"Verbose output"` + }{} + + args := []string{ + "-f", + } + + p := NewParser(&opts, 0) + args, err := p.ParseArgs(args) + + if err == nil { + t.Fatal("Expected error for unknown argument") + } +} + +func TestIgnoreUnknownFlags(t *testing.T) { + var opts = struct { + Verbose []bool `short:"v" long:"verbose" description:"Verbose output"` + }{} + + args := []string{ + "hello", + "world", + "-v", + "--foo=bar", + "--verbose", + "-f", + } + + p := NewParser(&opts, IgnoreUnknown) + args, err := p.ParseArgs(args) + + if err != nil { + t.Fatal(err) + } + + exargs := []string{ + "hello", + "world", + "--foo=bar", + "-f", + } + + issame := (len(args) == len(exargs)) + + if issame { + for i := 0; i < len(args); i++ { + if args[i] != exargs[i] { + issame = false + break + } + } + } + + if !issame { + t.Fatalf("Expected %v but got %v", exargs, args) + } +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/.gitignore b/Godeps/_workspace/src/github.com/juju/errors/.gitignore new file mode 100644 index 000000000..836562412 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/.gitignore @@ -0,0 +1,23 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test diff --git a/Godeps/_workspace/src/github.com/juju/errors/LICENSE b/Godeps/_workspace/src/github.com/juju/errors/LICENSE new file mode 100644 index 000000000..ade9307b3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/LICENSE @@ -0,0 +1,191 @@ +All files in this repository are licensed as follows. If you contribute +to this repository, it is assumed that you license your contribution +under the same license unless you state otherwise. + +All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file. + +This software is licensed under the LGPLv3, included below. + +As a special exception to the GNU Lesser General Public License version 3 +("LGPL3"), the copyright holders of this Library give you permission to +convey to a third party a Combined Work that links statically or dynamically +to this Library without providing any Minimal Corresponding Source or +Minimal Application Code as set out in 4d or providing the installation +information set out in section 4e, provided that you comply with the other +provisions of LGPL3 and provided that you meet, for the Application the +terms and conditions of the license(s) which apply to the Application. + +Except as stated in this special exception, the provisions of LGPL3 will +continue to comply in full to this Library. If you modify this Library, you +may apply this exception to your version of this Library, but you are not +obliged to do so. If you do not wish to do so, delete this exception +statement from your version. This exception does not (and cannot) modify any +license terms which apply to the Application, with which you must still +comply. + + + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/Godeps/_workspace/src/github.com/juju/errors/Makefile b/Godeps/_workspace/src/github.com/juju/errors/Makefile new file mode 100644 index 000000000..ab7c2e6cf --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/Makefile @@ -0,0 +1,11 @@ +default: check + +check: + go test && go test -compiler gccgo + +docs: + godoc2md github.com/juju/errors > README.md + sed -i 's|\[godoc-link-here\]|[![GoDoc](https://godoc.org/github.com/juju/errors?status.svg)](https://godoc.org/github.com/juju/errors)|' README.md + + +.PHONY: default check docs diff --git a/Godeps/_workspace/src/github.com/juju/errors/README.md b/Godeps/_workspace/src/github.com/juju/errors/README.md new file mode 100644 index 000000000..ee248911f --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/README.md @@ -0,0 +1,536 @@ + +# errors + import "github.com/juju/errors" + +[![GoDoc](https://godoc.org/github.com/juju/errors?status.svg)](https://godoc.org/github.com/juju/errors) + +The juju/errors provides an easy way to annotate errors without losing the +orginal error context. + +The exported `New` and `Errorf` functions are designed to replace the +`errors.New` and `fmt.Errorf` functions respectively. The same underlying +error is there, but the package also records the location at which the error +was created. + +A primary use case for this library is to add extra context any time an +error is returned from a function. + + + if err := SomeFunc(); err != nil { + return err + } + +This instead becomes: + + + if err := SomeFunc(); err != nil { + return errors.Trace(err) + } + +which just records the file and line number of the Trace call, or + + + if err := SomeFunc(); err != nil { + return errors.Annotate(err, "more context") + } + +which also adds an annotation to the error. + +When you want to check to see if an error is of a particular type, a helper +function is normally exported by the package that returned the error, like the +`os` package does. The underlying cause of the error is available using the +`Cause` function. + + + os.IsNotExist(errors.Cause(err)) + +The result of the `Error()` call on an annotated error is the annotations joined +with colons, then the result of the `Error()` method for the underlying error +that was the cause. + + + err := errors.Errorf("original") + err = errors.Annotatef(err, "context") + err = errors.Annotatef(err, "more context") + err.Error() -> "more context: context: original" + +Obviously recording the file, line and functions is not very useful if you +cannot get them back out again. + + + errors.ErrorStack(err) + +will return something like: + + + first error + github.com/juju/errors/annotation_test.go:193: + github.com/juju/errors/annotation_test.go:194: annotation + github.com/juju/errors/annotation_test.go:195: + github.com/juju/errors/annotation_test.go:196: more context + github.com/juju/errors/annotation_test.go:197: + +The first error was generated by an external system, so there was no location +associated. The second, fourth, and last lines were generated with Trace calls, +and the other two through Annotate. + +Sometimes when responding to an error you want to return a more specific error +for the situation. + + + if err := FindField(field); err != nil { + return errors.Wrap(err, errors.NotFoundf(field)) + } + +This returns an error where the complete error stack is still available, and +`errors.Cause()` will return the `NotFound` error. + + + + + + +## func AlreadyExistsf +``` go +func AlreadyExistsf(format string, args ...interface{}) error +``` +AlreadyExistsf returns an error which satisfies IsAlreadyExists(). + + +## func Annotate +``` go +func Annotate(other error, message string) error +``` +Annotate is used to add extra context to an existing error. The location of +the Annotate call is recorded with the annotations. The file, line and +function are also recorded. + +For example: + + + if err := SomeFunc(); err != nil { + return errors.Annotate(err, "failed to frombulate") + } + + +## func Annotatef +``` go +func Annotatef(other error, format string, args ...interface{}) error +``` +Annotatef is used to add extra context to an existing error. The location of +the Annotate call is recorded with the annotations. The file, line and +function are also recorded. + +For example: + + + if err := SomeFunc(); err != nil { + return errors.Annotatef(err, "failed to frombulate the %s", arg) + } + + +## func Cause +``` go +func Cause(err error) error +``` +Cause returns the cause of the given error. This will be either the +original error, or the result of a Wrap or Mask call. + +Cause is the usual way to diagnose errors that may have been wrapped by +the other errors functions. + + +## func DeferredAnnotatef +``` go +func DeferredAnnotatef(err *error, format string, args ...interface{}) +``` +DeferredAnnotatef annotates the given error (when it is not nil) with the given +format string and arguments (like fmt.Sprintf). If *err is nil, DeferredAnnotatef +does nothing. This method is used in a defer statement in order to annotate any +resulting error with the same message. + +For example: + + + defer DeferredAnnotatef(&err, "failed to frombulate the %s", arg) + + +## func Details +``` go +func Details(err error) string +``` +Details returns information about the stack of errors wrapped by err, in +the format: + + + [{filename:99: error one} {otherfile:55: cause of error one}] + +This is a terse alternative to ErrorStack as it returns a single line. + + +## func ErrorStack +``` go +func ErrorStack(err error) string +``` +ErrorStack returns a string representation of the annotated error. If the +error passed as the parameter is not an annotated error, the result is +simply the result of the Error() method on that error. + +If the error is an annotated error, a multi-line string is returned where +each line represents one entry in the annotation stack. The full filename +from the call stack is used in the output. + + + first error + github.com/juju/errors/annotation_test.go:193: + github.com/juju/errors/annotation_test.go:194: annotation + github.com/juju/errors/annotation_test.go:195: + github.com/juju/errors/annotation_test.go:196: more context + github.com/juju/errors/annotation_test.go:197: + + +## func Errorf +``` go +func Errorf(format string, args ...interface{}) error +``` +Errorf creates a new annotated error and records the location that the +error is created. This should be a drop in replacement for fmt.Errorf. + +For example: + + + return errors.Errorf("validation failed: %s", message) + + +## func IsAlreadyExists +``` go +func IsAlreadyExists(err error) bool +``` +IsAlreadyExists reports whether the error was created with +AlreadyExistsf() or NewAlreadyExists(). + + +## func IsNotFound +``` go +func IsNotFound(err error) bool +``` +IsNotFound reports whether err was created with NotFoundf() or +NewNotFound(). + + +## func IsNotImplemented +``` go +func IsNotImplemented(err error) bool +``` +IsNotImplemented reports whether err was created with +NotImplementedf() or NewNotImplemented(). + + +## func IsNotSupported +``` go +func IsNotSupported(err error) bool +``` +IsNotSupported reports whether the error was created with +NotSupportedf() or NewNotSupported(). + + +## func IsNotValid +``` go +func IsNotValid(err error) bool +``` +IsNotValid reports whether the error was created with NotValidf() or +NewNotValid(). + + +## func IsUnauthorized +``` go +func IsUnauthorized(err error) bool +``` +IsUnauthorized reports whether err was created with Unauthorizedf() or +NewUnauthorized(). + + +## func Mask +``` go +func Mask(other error) error +``` +Mask hides the underlying error type, and records the location of the masking. + + +## func Maskf +``` go +func Maskf(other error, format string, args ...interface{}) error +``` +Mask masks the given error with the given format string and arguments (like +fmt.Sprintf), returning a new error that maintains the error stack, but +hides the underlying error type. The error string still contains the full +annotations. If you want to hide the annotations, call Wrap. + + +## func New +``` go +func New(message string) error +``` +New is a drop in replacement for the standard libary errors module that records +the location that the error is created. + +For example: + + + return errors.New("validation failed") + + +## func NewAlreadyExists +``` go +func NewAlreadyExists(err error, msg string) error +``` +NewAlreadyExists returns an error which wraps err and satisfies +IsAlreadyExists(). + + +## func NewNotFound +``` go +func NewNotFound(err error, msg string) error +``` +NewNotFound returns an error which wraps err that satisfies +IsNotFound(). + + +## func NewNotImplemented +``` go +func NewNotImplemented(err error, msg string) error +``` +NewNotImplemented returns an error which wraps err and satisfies +IsNotImplemented(). + + +## func NewNotSupported +``` go +func NewNotSupported(err error, msg string) error +``` +NewNotSupported returns an error which wraps err and satisfies +IsNotSupported(). + + +## func NewNotValid +``` go +func NewNotValid(err error, msg string) error +``` +NewNotValid returns an error which wraps err and satisfies IsNotValid(). + + +## func NewUnauthorized +``` go +func NewUnauthorized(err error, msg string) error +``` +NewUnauthorized returns an error which wraps err and satisfies +IsUnauthorized(). + + +## func NotFoundf +``` go +func NotFoundf(format string, args ...interface{}) error +``` +NotFoundf returns an error which satisfies IsNotFound(). + + +## func NotImplementedf +``` go +func NotImplementedf(format string, args ...interface{}) error +``` +NotImplementedf returns an error which satisfies IsNotImplemented(). + + +## func NotSupportedf +``` go +func NotSupportedf(format string, args ...interface{}) error +``` +NotSupportedf returns an error which satisfies IsNotSupported(). + + +## func NotValidf +``` go +func NotValidf(format string, args ...interface{}) error +``` +NotValidf returns an error which satisfies IsNotValid(). + + +## func Trace +``` go +func Trace(other error) error +``` +Trace adds the location of the Trace call to the stack. The Cause of the +resulting error is the same as the error parameter. If the other error is +nil, the result will be nil. + +For example: + + + if err := SomeFunc(); err != nil { + return errors.Trace(err) + } + + +## func Unauthorizedf +``` go +func Unauthorizedf(format string, args ...interface{}) error +``` +Unauthorizedf returns an error which satisfies IsUnauthorized(). + + +## func Wrap +``` go +func Wrap(other, newDescriptive error) error +``` +Wrap changes the Cause of the error. The location of the Wrap call is also +stored in the error stack. + +For example: + + + if err := SomeFunc(); err != nil { + newErr := &packageError{"more context", private_value} + return errors.Wrap(err, newErr) + } + + +## func Wrapf +``` go +func Wrapf(other, newDescriptive error, format string, args ...interface{}) error +``` +Wrapf changes the Cause of the error, and adds an annotation. The location +of the Wrap call is also stored in the error stack. + +For example: + + + if err := SomeFunc(); err != nil { + return errors.Wrapf(err, simpleErrorType, "invalid value %q", value) + } + + + +## type Err +``` go +type Err struct { + // contains filtered or unexported fields +} +``` +Err holds a description of an error along with information about +where the error was created. + +It may be embedded in custom error types to add extra information that +this errors package can understand. + + + + + + + + + +### func NewErr +``` go +func NewErr(format string, args ...interface{}) Err +``` +NewErr is used to return an Err for the purpose of embedding in other +structures. The location is not specified, and needs to be set with a call +to SetLocation. + +For example: + + + type FooError struct { + errors.Err + code int + } + + func NewFooError(code int) error { + err := &FooError{errors.NewErr("foo"), code} + err.SetLocation(1) + return err + } + + + + +### func (\*Err) Cause +``` go +func (e *Err) Cause() error +``` +The Cause of an error is the most recent error in the error stack that +meets one of these criteria: the original error that was raised; the new +error that was passed into the Wrap function; the most recently masked +error; or nil if the error itself is considered the Cause. Normally this +method is not invoked directly, but instead through the Cause stand alone +function. + + + +### func (\*Err) Error +``` go +func (e *Err) Error() string +``` +Error implements error.Error. + + + +### func (\*Err) Location +``` go +func (e *Err) Location() (filename string, line int) +``` +Location is the file and line of where the error was most recently +created or annotated. + + + +### func (\*Err) Message +``` go +func (e *Err) Message() string +``` +Message returns the message stored with the most recent location. This is +the empty string if the most recent call was Trace, or the message stored +with Annotate or Mask. + + + +### func (\*Err) SetLocation +``` go +func (e *Err) SetLocation(callDepth int) +``` +SetLocation records the source location of the error at callDepth stack +frames above the call. + + + +### func (\*Err) StackTrace +``` go +func (e *Err) StackTrace() []string +``` +StackTrace returns one string for each location recorded in the stack of +errors. The first value is the originating error, with a line for each +other annotation or tracing of the error. + + + +### func (\*Err) Underlying +``` go +func (e *Err) Underlying() error +``` +Underlying returns the previous error in the error stack, if any. A client +should not ever really call this method. It is used to build the error +stack and should not be introspected by client calls. Or more +specifically, clients should not depend on anything but the `Cause` of an +error. + + + + + + + + + +- - - +Generated by [godoc2md](http://godoc.org/github.com/davecheney/godoc2md) \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/juju/errors/doc.go b/Godeps/_workspace/src/github.com/juju/errors/doc.go new file mode 100644 index 000000000..35b119aa3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/doc.go @@ -0,0 +1,81 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +/* +[godoc-link-here] + +The juju/errors provides an easy way to annotate errors without losing the +orginal error context. + +The exported `New` and `Errorf` functions are designed to replace the +`errors.New` and `fmt.Errorf` functions respectively. The same underlying +error is there, but the package also records the location at which the error +was created. + +A primary use case for this library is to add extra context any time an +error is returned from a function. + + if err := SomeFunc(); err != nil { + return err + } + +This instead becomes: + + if err := SomeFunc(); err != nil { + return errors.Trace(err) + } + +which just records the file and line number of the Trace call, or + + if err := SomeFunc(); err != nil { + return errors.Annotate(err, "more context") + } + +which also adds an annotation to the error. + +When you want to check to see if an error is of a particular type, a helper +function is normally exported by the package that returned the error, like the +`os` package does. The underlying cause of the error is available using the +`Cause` function. + + os.IsNotExist(errors.Cause(err)) + +The result of the `Error()` call on an annotated error is the annotations joined +with colons, then the result of the `Error()` method for the underlying error +that was the cause. + + err := errors.Errorf("original") + err = errors.Annotatef(err, "context") + err = errors.Annotatef(err, "more context") + err.Error() -> "more context: context: original" + +Obviously recording the file, line and functions is not very useful if you +cannot get them back out again. + + errors.ErrorStack(err) + +will return something like: + + first error + github.com/juju/errors/annotation_test.go:193: + github.com/juju/errors/annotation_test.go:194: annotation + github.com/juju/errors/annotation_test.go:195: + github.com/juju/errors/annotation_test.go:196: more context + github.com/juju/errors/annotation_test.go:197: + +The first error was generated by an external system, so there was no location +associated. The second, fourth, and last lines were generated with Trace calls, +and the other two through Annotate. + +Sometimes when responding to an error you want to return a more specific error +for the situation. + + if err := FindField(field); err != nil { + return errors.Wrap(err, errors.NotFoundf(field)) + } + +This returns an error where the complete error stack is still available, and +`errors.Cause()` will return the `NotFound` error. + +*/ +package errors diff --git a/Godeps/_workspace/src/github.com/juju/errors/error.go b/Godeps/_workspace/src/github.com/juju/errors/error.go new file mode 100644 index 000000000..4799acb02 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/error.go @@ -0,0 +1,122 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" + "reflect" + "runtime" +) + +// Err holds a description of an error along with information about +// where the error was created. +// +// It may be embedded in custom error types to add extra information that +// this errors package can understand. +type Err struct { + // message holds an annotation of the error. + message string + + // cause holds the cause of the error as returned + // by the Cause method. + cause error + + // previous holds the previous error in the error stack, if any. + previous error + + // file and line hold the source code location where the error was + // created. + file string + line int +} + +// NewErr is used to return an Err for the purpose of embedding in other +// structures. The location is not specified, and needs to be set with a call +// to SetLocation. +// +// For example: +// type FooError struct { +// errors.Err +// code int +// } +// +// func NewFooError(code int) error { +// err := &FooError{errors.NewErr("foo"), code} +// err.SetLocation(1) +// return err +// } +func NewErr(format string, args ...interface{}) Err { + return Err{ + message: fmt.Sprintf(format, args...), + } +} + +// Location is the file and line of where the error was most recently +// created or annotated. +func (e *Err) Location() (filename string, line int) { + return e.file, e.line +} + +// Underlying returns the previous error in the error stack, if any. A client +// should not ever really call this method. It is used to build the error +// stack and should not be introspected by client calls. Or more +// specifically, clients should not depend on anything but the `Cause` of an +// error. +func (e *Err) Underlying() error { + return e.previous +} + +// The Cause of an error is the most recent error in the error stack that +// meets one of these criteria: the original error that was raised; the new +// error that was passed into the Wrap function; the most recently masked +// error; or nil if the error itself is considered the Cause. Normally this +// method is not invoked directly, but instead through the Cause stand alone +// function. +func (e *Err) Cause() error { + return e.cause +} + +// Message returns the message stored with the most recent location. This is +// the empty string if the most recent call was Trace, or the message stored +// with Annotate or Mask. +func (e *Err) Message() string { + return e.message +} + +// Error implements error.Error. +func (e *Err) Error() string { + // We want to walk up the stack of errors showing the annotations + // as long as the cause is the same. + err := e.previous + if !sameError(Cause(err), e.cause) && e.cause != nil { + err = e.cause + } + switch { + case err == nil: + return e.message + case e.message == "": + return err.Error() + } + return fmt.Sprintf("%s: %v", e.message, err) +} + +// SetLocation records the source location of the error at callDepth stack +// frames above the call. +func (e *Err) SetLocation(callDepth int) { + _, file, line, _ := runtime.Caller(callDepth + 1) + e.file = trimGoPath(file) + e.line = line +} + +// StackTrace returns one string for each location recorded in the stack of +// errors. The first value is the originating error, with a line for each +// other annotation or tracing of the error. +func (e *Err) StackTrace() []string { + return errorStack(e) +} + +// Ideally we'd have a way to check identity, but deep equals will do. +func sameError(e1, e2 error) bool { + return reflect.DeepEqual(e1, e2) +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/error_test.go b/Godeps/_workspace/src/github.com/juju/errors/error_test.go new file mode 100644 index 000000000..ac1d2b423 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/error_test.go @@ -0,0 +1,161 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + "fmt" + "runtime" + + jc "github.com/juju/testing/checkers" + gc "gopkg.in/check.v1" + + "github.com/juju/errors" +) + +type errorsSuite struct{} + +var _ = gc.Suite(&errorsSuite{}) + +var someErr = errors.New("some error") //err varSomeErr + +func (*errorsSuite) TestErrorString(c *gc.C) { + for i, test := range []struct { + message string + generator func() error + expected string + }{ + { + message: "uncomparable errors", + generator: func() error { + err := errors.Annotatef(newNonComparableError("uncomparable"), "annotation") + return errors.Annotatef(err, "another") + }, + expected: "another: annotation: uncomparable", + }, { + message: "Errorf", + generator: func() error { + return errors.Errorf("first error") + }, + expected: "first error", + }, { + message: "annotated error", + generator: func() error { + err := errors.Errorf("first error") + return errors.Annotatef(err, "annotation") + }, + expected: "annotation: first error", + }, { + message: "test annotation format", + generator: func() error { + err := errors.Errorf("first %s", "error") + return errors.Annotatef(err, "%s", "annotation") + }, + expected: "annotation: first error", + }, { + message: "wrapped error", + generator: func() error { + err := newError("first error") + return errors.Wrap(err, newError("detailed error")) + }, + expected: "detailed error", + }, { + message: "wrapped annotated error", + generator: func() error { + err := errors.Errorf("first error") + err = errors.Annotatef(err, "annotated") + return errors.Wrap(err, fmt.Errorf("detailed error")) + }, + expected: "detailed error", + }, { + message: "annotated wrapped error", + generator: func() error { + err := errors.Errorf("first error") + err = errors.Wrap(err, fmt.Errorf("detailed error")) + return errors.Annotatef(err, "annotated") + }, + expected: "annotated: detailed error", + }, { + message: "traced, and annotated", + generator: func() error { + err := errors.New("first error") + err = errors.Trace(err) + err = errors.Annotate(err, "some context") + err = errors.Trace(err) + err = errors.Annotate(err, "more context") + return errors.Trace(err) + }, + expected: "more context: some context: first error", + }, { + message: "traced, and annotated, masked and annotated", + generator: func() error { + err := errors.New("first error") + err = errors.Trace(err) + err = errors.Annotate(err, "some context") + err = errors.Maskf(err, "masked") + err = errors.Annotate(err, "more context") + return errors.Trace(err) + }, + expected: "more context: masked: some context: first error", + }, + } { + c.Logf("%v: %s", i, test.message) + err := test.generator() + ok := c.Check(err.Error(), gc.Equals, test.expected) + if !ok { + c.Logf("%#v", test.generator()) + } + } +} + +type embed struct { + errors.Err +} + +func newEmbed(format string, args ...interface{}) *embed { + err := &embed{errors.NewErr(format, args...)} + err.SetLocation(1) + return err +} + +func (*errorsSuite) TestNewErr(c *gc.C) { + if runtime.Compiler == "gccgo" { + c.Skip("gccgo can't determine the location") + } + err := newEmbed("testing %d", 42) //err embedErr + c.Assert(err.Error(), gc.Equals, "testing 42") + c.Assert(errors.Cause(err), gc.Equals, err) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["embedErr"].String()) +} + +var _ error = (*embed)(nil) + +// This is an uncomparable error type, as it is a struct that supports the +// error interface (as opposed to a pointer type). +type error_ struct { + info string + slice []string +} + +// Create a non-comparable error +func newNonComparableError(message string) error { + return error_{info: message} +} + +func (e error_) Error() string { + return e.info +} + +func newError(message string) error { + return testError{message} +} + +// The testError is a value type error for ease of seeing results +// when the test fails. +type testError struct { + message string +} + +func (e testError) Error() string { + return e.message +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/errortypes.go b/Godeps/_workspace/src/github.com/juju/errors/errortypes.go new file mode 100644 index 000000000..aa970e99c --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/errortypes.go @@ -0,0 +1,235 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" +) + +// wrap is a helper to construct an *wrapper. +func wrap(err error, format, suffix string, args ...interface{}) Err { + newErr := Err{ + message: fmt.Sprintf(format+suffix, args...), + previous: err, + } + newErr.SetLocation(2) + return newErr +} + +// notFound represents an error when something has not been found. +type notFound struct { + Err +} + +// NotFoundf returns an error which satisfies IsNotFound(). +func NotFoundf(format string, args ...interface{}) error { + return ¬Found{wrap(nil, format, " not found", args...)} +} + +// NewNotFound returns an error which wraps err that satisfies +// IsNotFound(). +func NewNotFound(err error, msg string) error { + return ¬Found{wrap(err, msg, "")} +} + +// IsNotFound reports whether err was created with NotFoundf() or +// NewNotFound(). +func IsNotFound(err error) bool { + err = Cause(err) + _, ok := err.(*notFound) + return ok +} + +// userNotFound represents an error when an inexistent user is looked up. +type userNotFound struct { + Err +} + +// UserNotFoundf returns an error which satisfies IsUserNotFound(). +func UserNotFoundf(format string, args ...interface{}) error { + return &userNotFound{wrap(nil, format, " user not found", args...)} +} + +// NewUserNotFound returns an error which wraps err and satisfies +// IsUserNotFound(). +func NewUserNotFound(err error, msg string) error { + return &userNotFound{wrap(err, msg, "")} +} + +// IsUserNotFound reports whether err was created with UserNotFoundf() or +// NewUserNotFound(). +func IsUserNotFound(err error) bool { + err = Cause(err) + _, ok := err.(*userNotFound) + return ok +} + +// unauthorized represents an error when an operation is unauthorized. +type unauthorized struct { + Err +} + +// Unauthorizedf returns an error which satisfies IsUnauthorized(). +func Unauthorizedf(format string, args ...interface{}) error { + return &unauthorized{wrap(nil, format, "", args...)} +} + +// NewUnauthorized returns an error which wraps err and satisfies +// IsUnauthorized(). +func NewUnauthorized(err error, msg string) error { + return &unauthorized{wrap(err, msg, "")} +} + +// IsUnauthorized reports whether err was created with Unauthorizedf() or +// NewUnauthorized(). +func IsUnauthorized(err error) bool { + err = Cause(err) + _, ok := err.(*unauthorized) + return ok +} + +// notImplemented represents an error when something is not +// implemented. +type notImplemented struct { + Err +} + +// NotImplementedf returns an error which satisfies IsNotImplemented(). +func NotImplementedf(format string, args ...interface{}) error { + return ¬Implemented{wrap(nil, format, " not implemented", args...)} +} + +// NewNotImplemented returns an error which wraps err and satisfies +// IsNotImplemented(). +func NewNotImplemented(err error, msg string) error { + return ¬Implemented{wrap(err, msg, "")} +} + +// IsNotImplemented reports whether err was created with +// NotImplementedf() or NewNotImplemented(). +func IsNotImplemented(err error) bool { + err = Cause(err) + _, ok := err.(*notImplemented) + return ok +} + +// alreadyExists represents and error when something already exists. +type alreadyExists struct { + Err +} + +// AlreadyExistsf returns an error which satisfies IsAlreadyExists(). +func AlreadyExistsf(format string, args ...interface{}) error { + return &alreadyExists{wrap(nil, format, " already exists", args...)} +} + +// NewAlreadyExists returns an error which wraps err and satisfies +// IsAlreadyExists(). +func NewAlreadyExists(err error, msg string) error { + return &alreadyExists{wrap(err, msg, "")} +} + +// IsAlreadyExists reports whether the error was created with +// AlreadyExistsf() or NewAlreadyExists(). +func IsAlreadyExists(err error) bool { + err = Cause(err) + _, ok := err.(*alreadyExists) + return ok +} + +// notSupported represents an error when something is not supported. +type notSupported struct { + Err +} + +// NotSupportedf returns an error which satisfies IsNotSupported(). +func NotSupportedf(format string, args ...interface{}) error { + return ¬Supported{wrap(nil, format, " not supported", args...)} +} + +// NewNotSupported returns an error which wraps err and satisfies +// IsNotSupported(). +func NewNotSupported(err error, msg string) error { + return ¬Supported{wrap(err, msg, "")} +} + +// IsNotSupported reports whether the error was created with +// NotSupportedf() or NewNotSupported(). +func IsNotSupported(err error) bool { + err = Cause(err) + _, ok := err.(*notSupported) + return ok +} + +// notValid represents an error when something is not valid. +type notValid struct { + Err +} + +// NotValidf returns an error which satisfies IsNotValid(). +func NotValidf(format string, args ...interface{}) error { + return ¬Valid{wrap(nil, format, " not valid", args...)} +} + +// NewNotValid returns an error which wraps err and satisfies IsNotValid(). +func NewNotValid(err error, msg string) error { + return ¬Valid{wrap(err, msg, "")} +} + +// IsNotValid reports whether the error was created with NotValidf() or +// NewNotValid(). +func IsNotValid(err error) bool { + err = Cause(err) + _, ok := err.(*notValid) + return ok +} + +// notProvisioned represents an error when something is not yet provisioned. +type notProvisioned struct { + Err +} + +// NotProvisionedf returns an error which satisfies IsNotProvisioned(). +func NotProvisionedf(format string, args ...interface{}) error { + return ¬Provisioned{wrap(nil, format, " not provisioned", args...)} +} + +// NewNotProvisioned returns an error which wraps err that satisfies +// IsNotProvisioned(). +func NewNotProvisioned(err error, msg string) error { + return ¬Provisioned{wrap(err, msg, "")} +} + +// IsNotProvisioned reports whether err was created with NotProvisionedf() or +// NewNotProvisioned(). +func IsNotProvisioned(err error) bool { + err = Cause(err) + _, ok := err.(*notProvisioned) + return ok +} + +// notAssigned represents an error when something is not yet assigned to +// something else. +type notAssigned struct { + Err +} + +// NotAssignedf returns an error which satisfies IsNotAssigned(). +func NotAssignedf(format string, args ...interface{}) error { + return ¬Assigned{wrap(nil, format, " not assigned", args...)} +} + +// NewNotAssigned returns an error which wraps err that satisfies +// IsNotAssigned(). +func NewNotAssigned(err error, msg string) error { + return ¬Assigned{wrap(err, msg, "")} +} + +// IsNotAssigned reports whether err was created with NotAssignedf() or +// NewNotAssigned(). +func IsNotAssigned(err error) bool { + err = Cause(err) + _, ok := err.(*notAssigned) + return ok +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/errortypes_test.go b/Godeps/_workspace/src/github.com/juju/errors/errortypes_test.go new file mode 100644 index 000000000..943772b5d --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/errortypes_test.go @@ -0,0 +1,171 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + stderrors "errors" + "fmt" + "reflect" + "runtime" + + "github.com/juju/errors" + jc "github.com/juju/testing/checkers" + gc "gopkg.in/check.v1" +) + +// errorInfo holds information about a single error type: a satisfier +// function, wrapping and variable arguments constructors and message +// suffix. +type errorInfo struct { + satisfier func(error) bool + argsConstructor func(string, ...interface{}) error + wrapConstructor func(error, string) error + suffix string +} + +// allErrors holds information for all defined errors. When adding new +// errors, add them here as well to include them in tests. +var allErrors = []*errorInfo{ + &errorInfo{errors.IsNotFound, errors.NotFoundf, errors.NewNotFound, " not found"}, + &errorInfo{errors.IsUserNotFound, errors.UserNotFoundf, errors.NewUserNotFound, " user not found"}, + &errorInfo{errors.IsUnauthorized, errors.Unauthorizedf, errors.NewUnauthorized, ""}, + &errorInfo{errors.IsNotImplemented, errors.NotImplementedf, errors.NewNotImplemented, " not implemented"}, + &errorInfo{errors.IsAlreadyExists, errors.AlreadyExistsf, errors.NewAlreadyExists, " already exists"}, + &errorInfo{errors.IsNotSupported, errors.NotSupportedf, errors.NewNotSupported, " not supported"}, + &errorInfo{errors.IsNotValid, errors.NotValidf, errors.NewNotValid, " not valid"}, + &errorInfo{errors.IsNotProvisioned, errors.NotProvisionedf, errors.NewNotProvisioned, " not provisioned"}, + &errorInfo{errors.IsNotAssigned, errors.NotAssignedf, errors.NewNotAssigned, " not assigned"}, +} + +type errorTypeSuite struct{} + +var _ = gc.Suite(&errorTypeSuite{}) + +func (t *errorInfo) satisfierName() string { + value := reflect.ValueOf(t.satisfier) + f := runtime.FuncForPC(value.Pointer()) + return f.Name() +} + +func (t *errorInfo) equal(t0 *errorInfo) bool { + if t0 == nil { + return false + } + return t.satisfierName() == t0.satisfierName() +} + +type errorTest struct { + err error + message string + errInfo *errorInfo +} + +func deferredAnnotatef(err error, format string, args ...interface{}) error { + errors.DeferredAnnotatef(&err, format, args...) + return err +} + +func mustSatisfy(c *gc.C, err error, errInfo *errorInfo) { + if errInfo != nil { + msg := fmt.Sprintf("%#v must satisfy %v", err, errInfo.satisfierName()) + c.Check(err, jc.Satisfies, errInfo.satisfier, gc.Commentf(msg)) + } +} + +func mustNotSatisfy(c *gc.C, err error, errInfo *errorInfo) { + if errInfo != nil { + msg := fmt.Sprintf("%#v must not satisfy %v", err, errInfo.satisfierName()) + c.Check(err, gc.Not(jc.Satisfies), errInfo.satisfier, gc.Commentf(msg)) + } +} + +func checkErrorMatches(c *gc.C, err error, message string, errInfo *errorInfo) { + if message == "" { + c.Check(err, gc.IsNil) + c.Check(errInfo, gc.IsNil) + } else { + c.Check(err, gc.ErrorMatches, message) + } +} + +func runErrorTests(c *gc.C, errorTests []errorTest, checkMustSatisfy bool) { + for i, t := range errorTests { + c.Logf("test %d: %T: %v", i, t.err, t.err) + checkErrorMatches(c, t.err, t.message, t.errInfo) + if checkMustSatisfy { + mustSatisfy(c, t.err, t.errInfo) + } + + // Check all other satisfiers to make sure none match. + for _, otherErrInfo := range allErrors { + if checkMustSatisfy && otherErrInfo.equal(t.errInfo) { + continue + } + mustNotSatisfy(c, t.err, otherErrInfo) + } + } +} + +func (*errorTypeSuite) TestDeferredAnnotatef(c *gc.C) { + // Ensure DeferredAnnotatef annotates the errors. + errorTests := []errorTest{} + for _, errInfo := range allErrors { + errorTests = append(errorTests, []errorTest{{ + deferredAnnotatef(nil, "comment"), + "", + nil, + }, { + deferredAnnotatef(stderrors.New("blast"), "comment"), + "comment: blast", + nil, + }, { + deferredAnnotatef(errInfo.argsConstructor("foo %d", 42), "comment %d", 69), + "comment 69: foo 42" + errInfo.suffix, + errInfo, + }, { + deferredAnnotatef(errInfo.argsConstructor(""), "comment"), + "comment: " + errInfo.suffix, + errInfo, + }, { + deferredAnnotatef(errInfo.wrapConstructor(stderrors.New("pow!"), "woo"), "comment"), + "comment: woo: pow!", + errInfo, + }}...) + } + + runErrorTests(c, errorTests, true) +} + +func (*errorTypeSuite) TestAllErrors(c *gc.C) { + errorTests := []errorTest{} + for _, errInfo := range allErrors { + errorTests = append(errorTests, []errorTest{{ + nil, + "", + nil, + }, { + errInfo.argsConstructor("foo %d", 42), + "foo 42" + errInfo.suffix, + errInfo, + }, { + errInfo.argsConstructor(""), + errInfo.suffix, + errInfo, + }, { + errInfo.wrapConstructor(stderrors.New("pow!"), "prefix"), + "prefix: pow!", + errInfo, + }, { + errInfo.wrapConstructor(stderrors.New("pow!"), ""), + "pow!", + errInfo, + }, { + errInfo.wrapConstructor(nil, "prefix"), + "prefix", + errInfo, + }}...) + } + + runErrorTests(c, errorTests, true) +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/example_test.go b/Godeps/_workspace/src/github.com/juju/errors/example_test.go new file mode 100644 index 000000000..2a79cf489 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/example_test.go @@ -0,0 +1,23 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + "fmt" + + "github.com/juju/errors" +) + +func ExampleTrace() { + var err1 error = fmt.Errorf("something wicked this way comes") + var err2 error = nil + + // Tracing a non nil error will return an error + fmt.Println(errors.Trace(err1)) + // Tracing nil will return nil + fmt.Println(errors.Trace(err2)) + + // Output: something wicked this way comes + // +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/export_test.go b/Godeps/_workspace/src/github.com/juju/errors/export_test.go new file mode 100644 index 000000000..db57ec81c --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/export_test.go @@ -0,0 +1,12 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +// Since variables are declared before the init block, in order to get the goPath +// we need to return it rather than just reference it. +func GoPath() string { + return goPath +} + +var TrimGoPath = trimGoPath diff --git a/Godeps/_workspace/src/github.com/juju/errors/functions.go b/Godeps/_workspace/src/github.com/juju/errors/functions.go new file mode 100644 index 000000000..994208d8d --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/functions.go @@ -0,0 +1,330 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" + "strings" +) + +// New is a drop in replacement for the standard libary errors module that records +// the location that the error is created. +// +// For example: +// return errors.New("validation failed") +// +func New(message string) error { + err := &Err{message: message} + err.SetLocation(1) + return err +} + +// Errorf creates a new annotated error and records the location that the +// error is created. This should be a drop in replacement for fmt.Errorf. +// +// For example: +// return errors.Errorf("validation failed: %s", message) +// +func Errorf(format string, args ...interface{}) error { + err := &Err{message: fmt.Sprintf(format, args...)} + err.SetLocation(1) + return err +} + +// Trace adds the location of the Trace call to the stack. The Cause of the +// resulting error is the same as the error parameter. If the other error is +// nil, the result will be nil. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Trace(err) +// } +// +func Trace(other error) error { + if other == nil { + return nil + } + err := &Err{previous: other, cause: Cause(other)} + err.SetLocation(1) + return err +} + +// Annotate is used to add extra context to an existing error. The location of +// the Annotate call is recorded with the annotations. The file, line and +// function are also recorded. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Annotate(err, "failed to frombulate") +// } +// +func Annotate(other error, message string) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + cause: Cause(other), + message: message, + } + err.SetLocation(1) + return err +} + +// Annotatef is used to add extra context to an existing error. The location of +// the Annotate call is recorded with the annotations. The file, line and +// function are also recorded. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Annotatef(err, "failed to frombulate the %s", arg) +// } +// +func Annotatef(other error, format string, args ...interface{}) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + cause: Cause(other), + message: fmt.Sprintf(format, args...), + } + err.SetLocation(1) + return err +} + +// DeferredAnnotatef annotates the given error (when it is not nil) with the given +// format string and arguments (like fmt.Sprintf). If *err is nil, DeferredAnnotatef +// does nothing. This method is used in a defer statement in order to annotate any +// resulting error with the same message. +// +// For example: +// +// defer DeferredAnnotatef(&err, "failed to frombulate the %s", arg) +// +func DeferredAnnotatef(err *error, format string, args ...interface{}) { + if *err == nil { + return + } + newErr := &Err{ + message: fmt.Sprintf(format, args...), + cause: Cause(*err), + previous: *err, + } + newErr.SetLocation(1) + *err = newErr +} + +// Wrap changes the Cause of the error. The location of the Wrap call is also +// stored in the error stack. +// +// For example: +// if err := SomeFunc(); err != nil { +// newErr := &packageError{"more context", private_value} +// return errors.Wrap(err, newErr) +// } +// +func Wrap(other, newDescriptive error) error { + err := &Err{ + previous: other, + cause: newDescriptive, + } + err.SetLocation(1) + return err +} + +// Wrapf changes the Cause of the error, and adds an annotation. The location +// of the Wrap call is also stored in the error stack. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Wrapf(err, simpleErrorType, "invalid value %q", value) +// } +// +func Wrapf(other, newDescriptive error, format string, args ...interface{}) error { + err := &Err{ + message: fmt.Sprintf(format, args...), + previous: other, + cause: newDescriptive, + } + err.SetLocation(1) + return err +} + +// Mask masks the given error with the given format string and arguments (like +// fmt.Sprintf), returning a new error that maintains the error stack, but +// hides the underlying error type. The error string still contains the full +// annotations. If you want to hide the annotations, call Wrap. +func Maskf(other error, format string, args ...interface{}) error { + if other == nil { + return nil + } + err := &Err{ + message: fmt.Sprintf(format, args...), + previous: other, + } + err.SetLocation(1) + return err +} + +// Mask hides the underlying error type, and records the location of the masking. +func Mask(other error) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + } + err.SetLocation(1) + return err +} + +// Cause returns the cause of the given error. This will be either the +// original error, or the result of a Wrap or Mask call. +// +// Cause is the usual way to diagnose errors that may have been wrapped by +// the other errors functions. +func Cause(err error) error { + var diag error + if err, ok := err.(causer); ok { + diag = err.Cause() + } + if diag != nil { + return diag + } + return err +} + +type causer interface { + Cause() error +} + +type wrapper interface { + // Message returns the top level error message, + // not including the message from the Previous + // error. + Message() string + + // Underlying returns the Previous error, or nil + // if there is none. + Underlying() error +} + +type locationer interface { + Location() (string, int) +} + +var ( + _ wrapper = (*Err)(nil) + _ locationer = (*Err)(nil) + _ causer = (*Err)(nil) +) + +// Details returns information about the stack of errors wrapped by err, in +// the format: +// +// [{filename:99: error one} {otherfile:55: cause of error one}] +// +// This is a terse alternative to ErrorStack as it returns a single line. +func Details(err error) string { + if err == nil { + return "[]" + } + var s []byte + s = append(s, '[') + for { + s = append(s, '{') + if err, ok := err.(locationer); ok { + file, line := err.Location() + if file != "" { + s = append(s, fmt.Sprintf("%s:%d", file, line)...) + s = append(s, ": "...) + } + } + if cerr, ok := err.(wrapper); ok { + s = append(s, cerr.Message()...) + err = cerr.Underlying() + } else { + s = append(s, err.Error()...) + err = nil + } + s = append(s, '}') + if err == nil { + break + } + s = append(s, ' ') + } + s = append(s, ']') + return string(s) +} + +// ErrorStack returns a string representation of the annotated error. If the +// error passed as the parameter is not an annotated error, the result is +// simply the result of the Error() method on that error. +// +// If the error is an annotated error, a multi-line string is returned where +// each line represents one entry in the annotation stack. The full filename +// from the call stack is used in the output. +// +// first error +// github.com/juju/errors/annotation_test.go:193: +// github.com/juju/errors/annotation_test.go:194: annotation +// github.com/juju/errors/annotation_test.go:195: +// github.com/juju/errors/annotation_test.go:196: more context +// github.com/juju/errors/annotation_test.go:197: +func ErrorStack(err error) string { + return strings.Join(errorStack(err), "\n") +} + +func errorStack(err error) []string { + if err == nil { + return nil + } + + // We want the first error first + var lines []string + for { + var buff []byte + if err, ok := err.(locationer); ok { + file, line := err.Location() + // Strip off the leading GOPATH/src path elements. + file = trimGoPath(file) + if file != "" { + buff = append(buff, fmt.Sprintf("%s:%d", file, line)...) + buff = append(buff, ": "...) + } + } + if cerr, ok := err.(wrapper); ok { + message := cerr.Message() + buff = append(buff, message...) + // If there is a cause for this error, and it is different to the cause + // of the underlying error, then output the error string in the stack trace. + var cause error + if err1, ok := err.(causer); ok { + cause = err1.Cause() + } + err = cerr.Underlying() + if cause != nil && !sameError(Cause(err), cause) { + if message != "" { + buff = append(buff, ": "...) + } + buff = append(buff, cause.Error()...) + } + } else { + buff = append(buff, err.Error()...) + err = nil + } + lines = append(lines, string(buff)) + if err == nil { + break + } + } + // reverse the lines to get the original error, which was at the end of + // the list, back to the start. + var result []string + for i := len(lines); i > 0; i-- { + result = append(result, lines[i-1]) + } + return result +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/functions_test.go b/Godeps/_workspace/src/github.com/juju/errors/functions_test.go new file mode 100644 index 000000000..7b1e43b15 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/functions_test.go @@ -0,0 +1,305 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + jc "github.com/juju/testing/checkers" + gc "gopkg.in/check.v1" + + "github.com/juju/errors" +) + +type functionSuite struct { +} + +var _ = gc.Suite(&functionSuite{}) + +func (*functionSuite) TestNew(c *gc.C) { + err := errors.New("testing") //err newTest + c.Assert(err.Error(), gc.Equals, "testing") + c.Assert(errors.Cause(err), gc.Equals, err) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["newTest"].String()) +} + +func (*functionSuite) TestErrorf(c *gc.C) { + err := errors.Errorf("testing %d", 42) //err errorfTest + c.Assert(err.Error(), gc.Equals, "testing 42") + c.Assert(errors.Cause(err), gc.Equals, err) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["errorfTest"].String()) +} + +func (*functionSuite) TestTrace(c *gc.C) { + first := errors.New("first") + err := errors.Trace(first) //err traceTest + c.Assert(err.Error(), gc.Equals, "first") + c.Assert(errors.Cause(err), gc.Equals, first) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["traceTest"].String()) + + c.Assert(errors.Trace(nil), gc.IsNil) +} + +func (*functionSuite) TestAnnotate(c *gc.C) { + first := errors.New("first") + err := errors.Annotate(first, "annotation") //err annotateTest + c.Assert(err.Error(), gc.Equals, "annotation: first") + c.Assert(errors.Cause(err), gc.Equals, first) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["annotateTest"].String()) + + c.Assert(errors.Annotate(nil, "annotate"), gc.IsNil) +} + +func (*functionSuite) TestAnnotatef(c *gc.C) { + first := errors.New("first") + err := errors.Annotatef(first, "annotation %d", 2) //err annotatefTest + c.Assert(err.Error(), gc.Equals, "annotation 2: first") + c.Assert(errors.Cause(err), gc.Equals, first) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["annotatefTest"].String()) + + c.Assert(errors.Annotatef(nil, "annotate"), gc.IsNil) +} + +func (*functionSuite) TestDeferredAnnotatef(c *gc.C) { + // NOTE: this test fails with gccgo + if runtime.Compiler == "gccgo" { + c.Skip("gccgo can't determine the location") + } + first := errors.New("first") + test := func() (err error) { + defer errors.DeferredAnnotatef(&err, "deferred %s", "annotate") + return first + } //err deferredAnnotate + err := test() + c.Assert(err.Error(), gc.Equals, "deferred annotate: first") + c.Assert(errors.Cause(err), gc.Equals, first) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["deferredAnnotate"].String()) + + err = nil + errors.DeferredAnnotatef(&err, "deferred %s", "annotate") + c.Assert(err, gc.IsNil) +} + +func (*functionSuite) TestWrap(c *gc.C) { + first := errors.New("first") //err wrapFirst + detailed := errors.New("detailed") + err := errors.Wrap(first, detailed) //err wrapTest + c.Assert(err.Error(), gc.Equals, "detailed") + c.Assert(errors.Cause(err), gc.Equals, detailed) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["wrapFirst"].String()) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["wrapTest"].String()) +} + +func (*functionSuite) TestWrapOfNil(c *gc.C) { + detailed := errors.New("detailed") + err := errors.Wrap(nil, detailed) //err nilWrapTest + c.Assert(err.Error(), gc.Equals, "detailed") + c.Assert(errors.Cause(err), gc.Equals, detailed) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["nilWrapTest"].String()) +} + +func (*functionSuite) TestWrapf(c *gc.C) { + first := errors.New("first") //err wrapfFirst + detailed := errors.New("detailed") + err := errors.Wrapf(first, detailed, "value %d", 42) //err wrapfTest + c.Assert(err.Error(), gc.Equals, "value 42: detailed") + c.Assert(errors.Cause(err), gc.Equals, detailed) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["wrapfFirst"].String()) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["wrapfTest"].String()) +} + +func (*functionSuite) TestWrapfOfNil(c *gc.C) { + detailed := errors.New("detailed") + err := errors.Wrapf(nil, detailed, "value %d", 42) //err nilWrapfTest + c.Assert(err.Error(), gc.Equals, "value 42: detailed") + c.Assert(errors.Cause(err), gc.Equals, detailed) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["nilWrapfTest"].String()) +} + +func (*functionSuite) TestMask(c *gc.C) { + first := errors.New("first") + err := errors.Mask(first) //err maskTest + c.Assert(err.Error(), gc.Equals, "first") + c.Assert(errors.Cause(err), gc.Equals, err) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["maskTest"].String()) + + c.Assert(errors.Mask(nil), gc.IsNil) +} + +func (*functionSuite) TestMaskf(c *gc.C) { + first := errors.New("first") + err := errors.Maskf(first, "masked %d", 42) //err maskfTest + c.Assert(err.Error(), gc.Equals, "masked 42: first") + c.Assert(errors.Cause(err), gc.Equals, err) + c.Assert(errors.Details(err), jc.Contains, tagToLocation["maskfTest"].String()) + + c.Assert(errors.Maskf(nil, "mask"), gc.IsNil) +} + +func (*functionSuite) TestCause(c *gc.C) { + c.Assert(errors.Cause(nil), gc.IsNil) + c.Assert(errors.Cause(someErr), gc.Equals, someErr) + + fmtErr := fmt.Errorf("simple") + c.Assert(errors.Cause(fmtErr), gc.Equals, fmtErr) + + err := errors.Wrap(someErr, fmtErr) + c.Assert(errors.Cause(err), gc.Equals, fmtErr) + + err = errors.Annotate(err, "annotated") + c.Assert(errors.Cause(err), gc.Equals, fmtErr) + + err = errors.Maskf(err, "maksed") + c.Assert(errors.Cause(err), gc.Equals, err) + + // Look for a file that we know isn't there. + dir := c.MkDir() + _, err = os.Stat(filepath.Join(dir, "not-there")) + c.Assert(os.IsNotExist(err), jc.IsTrue) + + err = errors.Annotatef(err, "wrap it") + // Now the error itself isn't a 'IsNotExist'. + c.Assert(os.IsNotExist(err), jc.IsFalse) + // However if we use the Check method, it is. + c.Assert(os.IsNotExist(errors.Cause(err)), jc.IsTrue) +} + +func (s *functionSuite) TestDetails(c *gc.C) { + if runtime.Compiler == "gccgo" { + c.Skip("gccgo can't determine the location") + } + c.Assert(errors.Details(nil), gc.Equals, "[]") + + otherErr := fmt.Errorf("other") + checkDetails(c, otherErr, "[{other}]") + + err0 := newEmbed("foo") //err TestStack#0 + checkDetails(c, err0, "[{$TestStack#0$: foo}]") + + err1 := errors.Annotate(err0, "bar") //err TestStack#1 + checkDetails(c, err1, "[{$TestStack#1$: bar} {$TestStack#0$: foo}]") + + err2 := errors.Trace(err1) //err TestStack#2 + checkDetails(c, err2, "[{$TestStack#2$: } {$TestStack#1$: bar} {$TestStack#0$: foo}]") +} + +type tracer interface { + StackTrace() []string +} + +func (*functionSuite) TestErrorStack(c *gc.C) { + for i, test := range []struct { + message string + generator func() error + expected string + tracer bool + }{ + { + message: "nil", + generator: func() error { + return nil + }, + }, { + message: "raw error", + generator: func() error { + return fmt.Errorf("raw") + }, + expected: "raw", + }, { + message: "single error stack", + generator: func() error { + return errors.New("first error") //err single + }, + expected: "$single$: first error", + tracer: true, + }, { + message: "annotated error", + generator: func() error { + err := errors.New("first error") //err annotated-0 + return errors.Annotate(err, "annotation") //err annotated-1 + }, + expected: "" + + "$annotated-0$: first error\n" + + "$annotated-1$: annotation", + tracer: true, + }, { + message: "wrapped error", + generator: func() error { + err := errors.New("first error") //err wrapped-0 + return errors.Wrap(err, newError("detailed error")) //err wrapped-1 + }, + expected: "" + + "$wrapped-0$: first error\n" + + "$wrapped-1$: detailed error", + tracer: true, + }, { + message: "annotated wrapped error", + generator: func() error { + err := errors.Errorf("first error") //err ann-wrap-0 + err = errors.Wrap(err, fmt.Errorf("detailed error")) //err ann-wrap-1 + return errors.Annotatef(err, "annotated") //err ann-wrap-2 + }, + expected: "" + + "$ann-wrap-0$: first error\n" + + "$ann-wrap-1$: detailed error\n" + + "$ann-wrap-2$: annotated", + tracer: true, + }, { + message: "traced, and annotated", + generator: func() error { + err := errors.New("first error") //err stack-0 + err = errors.Trace(err) //err stack-1 + err = errors.Annotate(err, "some context") //err stack-2 + err = errors.Trace(err) //err stack-3 + err = errors.Annotate(err, "more context") //err stack-4 + return errors.Trace(err) //err stack-5 + }, + expected: "" + + "$stack-0$: first error\n" + + "$stack-1$: \n" + + "$stack-2$: some context\n" + + "$stack-3$: \n" + + "$stack-4$: more context\n" + + "$stack-5$: ", + tracer: true, + }, { + message: "uncomparable, wrapped with a value error", + generator: func() error { + err := newNonComparableError("first error") //err mixed-0 + err = errors.Trace(err) //err mixed-1 + err = errors.Wrap(err, newError("value error")) //err mixed-2 + err = errors.Maskf(err, "masked") //err mixed-3 + err = errors.Annotate(err, "more context") //err mixed-4 + return errors.Trace(err) //err mixed-5 + }, + expected: "" + + "first error\n" + + "$mixed-1$: \n" + + "$mixed-2$: value error\n" + + "$mixed-3$: masked\n" + + "$mixed-4$: more context\n" + + "$mixed-5$: ", + tracer: true, + }, + } { + c.Logf("%v: %s", i, test.message) + err := test.generator() + expected := replaceLocations(test.expected) + stack := errors.ErrorStack(err) + ok := c.Check(stack, gc.Equals, expected) + if !ok { + c.Logf("%#v", err) + } + tracer, ok := err.(tracer) + c.Check(ok, gc.Equals, test.tracer) + if ok { + stackTrace := tracer.StackTrace() + c.Check(stackTrace, gc.DeepEquals, strings.Split(stack, "\n")) + } + } +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/package_test.go b/Godeps/_workspace/src/github.com/juju/errors/package_test.go new file mode 100644 index 000000000..5bbb8f04e --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/package_test.go @@ -0,0 +1,95 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + "fmt" + "io/ioutil" + "strings" + "testing" + + gc "gopkg.in/check.v1" + + "github.com/juju/errors" +) + +func Test(t *testing.T) { + gc.TestingT(t) +} + +func checkDetails(c *gc.C, err error, details string) { + c.Assert(err, gc.NotNil) + expectedDetails := replaceLocations(details) + c.Assert(errors.Details(err), gc.Equals, expectedDetails) +} + +func checkErr(c *gc.C, err, cause error, msg string, details string) { + c.Assert(err, gc.NotNil) + c.Assert(err.Error(), gc.Equals, msg) + c.Assert(errors.Cause(err), gc.Equals, cause) + expectedDetails := replaceLocations(details) + c.Assert(errors.Details(err), gc.Equals, expectedDetails) +} + +func replaceLocations(line string) string { + result := "" + for { + i := strings.Index(line, "$") + if i == -1 { + break + } + result += line[0:i] + line = line[i+1:] + i = strings.Index(line, "$") + if i == -1 { + panic("no second $") + } + result += location(line[0:i]).String() + line = line[i+1:] + } + result += line + return result +} + +func location(tag string) Location { + loc, ok := tagToLocation[tag] + if !ok { + panic(fmt.Sprintf("tag %q not found", tag)) + } + return loc +} + +type Location struct { + file string + line int +} + +func (loc Location) String() string { + return fmt.Sprintf("%s:%d", loc.file, loc.line) +} + +var tagToLocation = make(map[string]Location) + +func setLocationsForErrorTags(filename string) { + data, err := ioutil.ReadFile(filename) + if err != nil { + panic(err) + } + filename = "github.com/juju/errors/" + filename + lines := strings.Split(string(data), "\n") + for i, line := range lines { + if j := strings.Index(line, "//err "); j >= 0 { + tag := line[j+len("//err "):] + if _, found := tagToLocation[tag]; found { + panic(fmt.Sprintf("tag %q already processed previously", tag)) + } + tagToLocation[tag] = Location{file: filename, line: i + 1} + } + } +} + +func init() { + setLocationsForErrorTags("error_test.go") + setLocationsForErrorTags("functions_test.go") +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/path.go b/Godeps/_workspace/src/github.com/juju/errors/path.go new file mode 100644 index 000000000..3ec517c81 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/path.go @@ -0,0 +1,35 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "runtime" + "strings" +) + +// prefixSize is used internally to trim the user specific path from the +// front of the returned filenames from the runtime call stack. +var prefixSize int + +// goPath is the deduced path based on the location of this file as compiled. +var goPath string + +func init() { + _, file, _, ok := runtime.Caller(0) + if ok { + // We know that the end of the file should be: + // github.com/juju/errors/path.go + size := len(file) + suffix := len("github.com/juju/errors/path.go") + goPath = file[:size-suffix] + prefixSize = len(goPath) + } +} + +func trimGoPath(filename string) string { + if strings.HasPrefix(filename, goPath) { + return filename[prefixSize:] + } + return filename +} diff --git a/Godeps/_workspace/src/github.com/juju/errors/path_test.go b/Godeps/_workspace/src/github.com/juju/errors/path_test.go new file mode 100644 index 000000000..ef4f34f82 --- /dev/null +++ b/Godeps/_workspace/src/github.com/juju/errors/path_test.go @@ -0,0 +1,29 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors_test + +import ( + "path" + + gc "gopkg.in/check.v1" + + "github.com/juju/errors" +) + +type pathSuite struct{} + +var _ = gc.Suite(&pathSuite{}) + +func (*pathSuite) TestGoPathSet(c *gc.C) { + c.Assert(errors.GoPath(), gc.Not(gc.Equals), "") +} + +func (*pathSuite) TestTrimGoPath(c *gc.C) { + relativeImport := "github.com/foo/bar/baz.go" + filename := path.Join(errors.GoPath(), relativeImport) + c.Assert(errors.TrimGoPath(filename), gc.Equals, relativeImport) + + absoluteImport := "/usr/share/foo/bar/baz.go" + c.Assert(errors.TrimGoPath(absoluteImport), gc.Equals, absoluteImport) +} diff --git a/Godeps/_workspace/src/github.com/kr/fs/LICENSE b/Godeps/_workspace/src/github.com/kr/fs/LICENSE new file mode 100644 index 000000000..744875676 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/kr/fs/Readme b/Godeps/_workspace/src/github.com/kr/fs/Readme new file mode 100644 index 000000000..c95e13fc8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/Readme @@ -0,0 +1,3 @@ +Filesystem Package + +http://godoc.org/github.com/kr/fs diff --git a/Godeps/_workspace/src/github.com/kr/fs/example_test.go b/Godeps/_workspace/src/github.com/kr/fs/example_test.go new file mode 100644 index 000000000..77e0db9fe --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/example_test.go @@ -0,0 +1,19 @@ +package fs_test + +import ( + "fmt" + "os" + + "github.com/kr/fs" +) + +func ExampleWalker() { + walker := fs.Walk("/usr/lib") + for walker.Step() { + if err := walker.Err(); err != nil { + fmt.Fprintln(os.Stderr, err) + continue + } + fmt.Println(walker.Path()) + } +} diff --git a/Godeps/_workspace/src/github.com/kr/fs/filesystem.go b/Godeps/_workspace/src/github.com/kr/fs/filesystem.go new file mode 100644 index 000000000..f1c4805fb --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/filesystem.go @@ -0,0 +1,36 @@ +package fs + +import ( + "io/ioutil" + "os" + "path/filepath" +) + +// FileSystem defines the methods of an abstract filesystem. +type FileSystem interface { + + // ReadDir reads the directory named by dirname and returns a + // list of directory entries. + ReadDir(dirname string) ([]os.FileInfo, error) + + // Lstat returns a FileInfo describing the named file. If the file is a + // symbolic link, the returned FileInfo describes the symbolic link. Lstat + // makes no attempt to follow the link. + Lstat(name string) (os.FileInfo, error) + + // Join joins any number of path elements into a single path, adding a + // separator if necessary. The result is Cleaned; in particular, all + // empty strings are ignored. + // + // The separator is FileSystem specific. + Join(elem ...string) string +} + +// fs represents a FileSystem provided by the os package. +type fs struct{} + +func (f *fs) ReadDir(dirname string) ([]os.FileInfo, error) { return ioutil.ReadDir(dirname) } + +func (f *fs) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) } + +func (f *fs) Join(elem ...string) string { return filepath.Join(elem...) } diff --git a/Godeps/_workspace/src/github.com/kr/fs/walk.go b/Godeps/_workspace/src/github.com/kr/fs/walk.go new file mode 100644 index 000000000..6ffa1e0b2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/walk.go @@ -0,0 +1,95 @@ +// Package fs provides filesystem-related functions. +package fs + +import ( + "os" +) + +// Walker provides a convenient interface for iterating over the +// descendants of a filesystem path. +// Successive calls to the Step method will step through each +// file or directory in the tree, including the root. The files +// are walked in lexical order, which makes the output deterministic +// but means that for very large directories Walker can be inefficient. +// Walker does not follow symbolic links. +type Walker struct { + fs FileSystem + cur item + stack []item + descend bool +} + +type item struct { + path string + info os.FileInfo + err error +} + +// Walk returns a new Walker rooted at root. +func Walk(root string) *Walker { + return WalkFS(root, new(fs)) +} + +// WalkFS returns a new Walker rooted at root on the FileSystem fs. +func WalkFS(root string, fs FileSystem) *Walker { + info, err := fs.Lstat(root) + return &Walker{ + fs: fs, + stack: []item{{root, info, err}}, + } +} + +// Step advances the Walker to the next file or directory, +// which will then be available through the Path, Stat, +// and Err methods. +// It returns false when the walk stops at the end of the tree. +func (w *Walker) Step() bool { + if w.descend && w.cur.err == nil && w.cur.info.IsDir() { + list, err := w.fs.ReadDir(w.cur.path) + if err != nil { + w.cur.err = err + w.stack = append(w.stack, w.cur) + } else { + for i := len(list) - 1; i >= 0; i-- { + path := w.fs.Join(w.cur.path, list[i].Name()) + w.stack = append(w.stack, item{path, list[i], nil}) + } + } + } + + if len(w.stack) == 0 { + return false + } + i := len(w.stack) - 1 + w.cur = w.stack[i] + w.stack = w.stack[:i] + w.descend = true + return true +} + +// Path returns the path to the most recent file or directory +// visited by a call to Step. It contains the argument to Walk +// as a prefix; that is, if Walk is called with "dir", which is +// a directory containing the file "a", Path will return "dir/a". +func (w *Walker) Path() string { + return w.cur.path +} + +// Stat returns info for the most recent file or directory +// visited by a call to Step. +func (w *Walker) Stat() os.FileInfo { + return w.cur.info +} + +// Err returns the error, if any, for the most recent attempt +// by Step to visit a file or directory. If a directory has +// an error, w will not descend into that directory. +func (w *Walker) Err() error { + return w.cur.err +} + +// SkipDir causes the currently visited directory to be skipped. +// If w is not on a directory, SkipDir has no effect. +func (w *Walker) SkipDir() { + w.descend = false +} diff --git a/Godeps/_workspace/src/github.com/kr/fs/walk_test.go b/Godeps/_workspace/src/github.com/kr/fs/walk_test.go new file mode 100644 index 000000000..6f5ad2ad3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/kr/fs/walk_test.go @@ -0,0 +1,209 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/kr/fs" +) + +type PathTest struct { + path, result string +} + +type Node struct { + name string + entries []*Node // nil if the entry is a file + mark int +} + +var tree = &Node{ + "testdata", + []*Node{ + {"a", nil, 0}, + {"b", []*Node{}, 0}, + {"c", nil, 0}, + { + "d", + []*Node{ + {"x", nil, 0}, + {"y", []*Node{}, 0}, + { + "z", + []*Node{ + {"u", nil, 0}, + {"v", nil, 0}, + }, + 0, + }, + }, + 0, + }, + }, + 0, +} + +func walkTree(n *Node, path string, f func(path string, n *Node)) { + f(path, n) + for _, e := range n.entries { + walkTree(e, filepath.Join(path, e.name), f) + } +} + +func makeTree(t *testing.T) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.entries == nil { + fd, err := os.Create(path) + if err != nil { + t.Errorf("makeTree: %v", err) + return + } + fd.Close() + } else { + os.Mkdir(path, 0770) + } + }) +} + +func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) } + +func checkMarks(t *testing.T, report bool) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.mark != 1 && report { + t.Errorf("node %s mark = %d; expected 1", path, n.mark) + } + n.mark = 0 + }) +} + +// Assumes that each node name is unique. Good enough for a test. +// If clear is true, any incoming error is cleared before return. The errors +// are always accumulated, though. +func mark(path string, info os.FileInfo, err error, errors *[]error, clear bool) error { + if err != nil { + *errors = append(*errors, err) + if clear { + return nil + } + return err + } + name := info.Name() + walkTree(tree, tree.name, func(path string, n *Node) { + if n.name == name { + n.mark++ + } + }) + return nil +} + +func TestWalk(t *testing.T) { + makeTree(t) + errors := make([]error, 0, 10) + clear := true + markFn := func(walker *fs.Walker) (err error) { + for walker.Step() { + err = mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) + if err != nil { + break + } + } + return err + } + // Expect no errors. + err := markFn(fs.Walk(tree.name)) + if err != nil { + t.Fatalf("no error expected, found: %s", err) + } + if len(errors) != 0 { + t.Fatalf("unexpected errors: %s", errors) + } + checkMarks(t, true) + errors = errors[0:0] + + // Test permission errors. Only possible if we're not root + // and only on some file systems (AFS, FAT). To avoid errors during + // all.bash on those file systems, skip during go test -short. + if os.Getuid() > 0 && !testing.Short() { + // introduce 2 errors: chmod top-level directories to 0 + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0) + + // 3) capture errors, expect two. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + err := markFn(fs.Walk(tree.name)) + if err != nil { + t.Fatalf("expected no error return from Walk, got %s", err) + } + if len(errors) != 2 { + t.Errorf("expected 2 errors, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, true) + errors = errors[0:0] + + // 4) capture errors, stop after first error. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + clear = false // error will stop processing + err = markFn(fs.Walk(tree.name)) + if err == nil { + t.Fatalf("expected error return from Walk") + } + if len(errors) != 1 { + t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, false) + errors = errors[0:0] + + // restore permissions + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + } + + // cleanup + if err := os.RemoveAll(tree.name); err != nil { + t.Errorf("removeTree: %v", err) + } +} + +func TestBug3486(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=3486 + root, err := filepath.EvalSymlinks(runtime.GOROOT()) + if err != nil { + t.Fatal(err) + } + lib := filepath.Join(root, "lib") + src := filepath.Join(root, "src") + seenSrc := false + walker := fs.Walk(root) + for walker.Step() { + if walker.Err() != nil { + t.Fatal(walker.Err()) + } + + switch walker.Path() { + case lib: + walker.SkipDir() + case src: + seenSrc = true + } + } + if !seenSrc { + t.Fatalf("%q not seen", src) + } +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/CONTRIBUTORS b/Godeps/_workspace/src/github.com/pkg/sftp/CONTRIBUTORS new file mode 100644 index 000000000..7eff82364 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/CONTRIBUTORS @@ -0,0 +1,2 @@ +Dave Cheney +Saulius Gurklys diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/LICENSE b/Godeps/_workspace/src/github.com/pkg/sftp/LICENSE new file mode 100644 index 000000000..b7b53921e --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/LICENSE @@ -0,0 +1,9 @@ +Copyright (c) 2013, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/README.md b/Godeps/_workspace/src/github.com/pkg/sftp/README.md new file mode 100644 index 000000000..d7058eb61 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/README.md @@ -0,0 +1,27 @@ +sftp +---- + +The `sftp` package provides support for file system operations on remote ssh servers using the SFTP subsystem. + +[![Build Status](https://drone.io/github.com/pkg/sftp/status.png)](https://drone.io/github.com/pkg/sftp/latest) + +usage and examples +------------------ + +See [godoc.org/github.com/pkg/sftp](http://godoc.org/github.com/pkg/sftp) for examples and usage. + +The basic operation of the package mirrors the facilities of the [os](http://golang.org/pkg/os) package. + +The Walker interface for directory traversal is heavily inspired by Keith Rarick's [fs](http://godoc.org/github.com/kr/fs) package. + +roadmap +------- + + * Currently all traffic with the server is serialized, this can be improved by allowing overlapping requests/responses. + * There is way too much duplication in the Client methods. If there was an unmarshal(interface{}) method this would reduce a heap of the duplication. + * Implement integration tests by talking directly to a real opensftp-server process. This shouldn't be too difficult to implement with a small refactoring to the sftp.NewClient method. These tests should be gated on an -sftp.integration test flag. _in progress_ + +contributing +------------ + +Features, Issues, and Pull Requests are always welcome. diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/attrs.go b/Godeps/_workspace/src/github.com/pkg/sftp/attrs.go new file mode 100644 index 000000000..0d37db080 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/attrs.go @@ -0,0 +1,138 @@ +package sftp + +// ssh_FXP_ATTRS support +// see http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5 + +import ( + "os" + "syscall" + "time" +) + +const ( + ssh_FILEXFER_ATTR_SIZE = 0x00000001 + ssh_FILEXFER_ATTR_UIDGID = 0x00000002 + ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 + ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 + ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 +) + +// fileInfo is an artificial type designed to satisfy os.FileInfo. +type fileInfo struct { + name string + size int64 + mode os.FileMode + mtime time.Time + sys interface{} +} + +// Name returns the base name of the file. +func (fi *fileInfo) Name() string { return fi.name } + +// Size returns the length in bytes for regular files; system-dependent for others. +func (fi *fileInfo) Size() int64 { return fi.size } + +// Mode returns file mode bits. +func (fi *fileInfo) Mode() os.FileMode { return fi.mode } + +// ModTime returns the last modification time of the file. +func (fi *fileInfo) ModTime() time.Time { return fi.mtime } + +// IsDir returns true if the file is a directory. +func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } + +func (fi *fileInfo) Sys() interface{} { return fi.sys } + +// FileStat holds the original unmarshalled values from a call to READDIR or *STAT. +// It is exported for the purposes of accessing the raw values via os.FileInfo.Sys() +type FileStat struct { + Size uint64 + Mode uint32 + Mtime uint32 + Atime uint32 + Uid uint32 + Gid uint32 + Extended []StatExtended +} + +type StatExtended struct { + ExtType string + ExtData string +} + +func fileInfoFromStat(st *FileStat, name string) os.FileInfo { + fs := &fileInfo{ + name: name, + size: int64(st.Size), + mode: toFileMode(st.Mode), + mtime: time.Unix(int64(st.Mtime), 0), + sys: st, + } + return fs +} + +func unmarshalAttrs(b []byte) (*FileStat, []byte) { + flags, b := unmarshalUint32(b) + var fs FileStat + if flags&ssh_FILEXFER_ATTR_SIZE == ssh_FILEXFER_ATTR_SIZE { + fs.Size, b = unmarshalUint64(b) + } + if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { + fs.Uid, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { + fs.Gid, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_PERMISSIONS == ssh_FILEXFER_ATTR_PERMISSIONS { + fs.Mode, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_ACMODTIME == ssh_FILEXFER_ATTR_ACMODTIME { + fs.Atime, b = unmarshalUint32(b) + fs.Mtime, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_EXTENDED == ssh_FILEXFER_ATTR_EXTENDED { + var count uint32 + count, b = unmarshalUint32(b) + ext := make([]StatExtended, count, count) + for i := uint32(0); i < count; i++ { + var typ string + var data string + typ, b = unmarshalString(b) + data, b = unmarshalString(b) + ext[i] = StatExtended{typ, data} + } + fs.Extended = ext + } + return &fs, b +} + +// toFileMode converts sftp filemode bits to the os.FileMode specification +func toFileMode(mode uint32) os.FileMode { + var fm = os.FileMode(mode & 0777) + switch mode & syscall.S_IFMT { + case syscall.S_IFBLK: + fm |= os.ModeDevice + case syscall.S_IFCHR: + fm |= os.ModeDevice | os.ModeCharDevice + case syscall.S_IFDIR: + fm |= os.ModeDir + case syscall.S_IFIFO: + fm |= os.ModeNamedPipe + case syscall.S_IFLNK: + fm |= os.ModeSymlink + case syscall.S_IFREG: + // nothing to do + case syscall.S_IFSOCK: + fm |= os.ModeSocket + } + if mode&syscall.S_ISGID != 0 { + fm |= os.ModeSetgid + } + if mode&syscall.S_ISUID != 0 { + fm |= os.ModeSetuid + } + if mode&syscall.S_ISVTX != 0 { + fm |= os.ModeSticky + } + return fm +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/attrs_test.go b/Godeps/_workspace/src/github.com/pkg/sftp/attrs_test.go new file mode 100644 index 000000000..d55290599 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/attrs_test.go @@ -0,0 +1,45 @@ +package sftp + +import ( + "bytes" + "os" + "reflect" + "testing" + "time" +) + +// ensure that attrs implemenst os.FileInfo +var _ os.FileInfo = new(fileInfo) + +var unmarshalAttrsTests = []struct { + b []byte + want *fileInfo + rest []byte +}{ + {marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + }{ssh_FILEXFER_ATTR_SIZE, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + Permissions uint32 + }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + Uid, Gid, Permissions uint32 + }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, +} + +func TestUnmarshalAttrs(t *testing.T) { + for _, tt := range unmarshalAttrsTests { + stat, rest := unmarshalAttrs(tt.b) + got := fileInfoFromStat(stat, "") + tt.want.sys = got.Sys() + if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) { + t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/client.go b/Godeps/_workspace/src/github.com/pkg/sftp/client.go new file mode 100644 index 000000000..f48b0be7b --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/client.go @@ -0,0 +1,717 @@ +package sftp + +import ( + "encoding" + "io" + "os" + "path" + "sync" + "time" + + "github.com/kr/fs" + + "golang.org/x/crypto/ssh" +) + +// New creates a new SFTP client on conn. +func NewClient(conn *ssh.Client) (*Client, error) { + s, err := conn.NewSession() + if err != nil { + return nil, err + } + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + pw, err := s.StdinPipe() + if err != nil { + return nil, err + } + pr, err := s.StdoutPipe() + if err != nil { + return nil, err + } + + return NewClientPipe(pr, pw) +} + +// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. +// This can be used for connecting to an SFTP server over TCP/TLS or by using +// the system's ssh client program (e.g. via exec.Command). +func NewClientPipe(rd io.Reader, wr io.WriteCloser) (*Client, error) { + sftp := &Client{ + w: wr, + r: rd, + } + if err := sftp.sendInit(); err != nil { + return nil, err + } + return sftp, sftp.recvVersion() +} + +// Client represents an SFTP session on a *ssh.ClientConn SSH connection. +// Multiple Clients can be active on a single SSH connection, and a Client +// may be called concurrently from multiple Goroutines. +// +// Client implements the github.com/kr/fs.FileSystem interface. +type Client struct { + w io.WriteCloser + r io.Reader + mu sync.Mutex // locks mu and seralises commands to the server + nextid uint32 +} + +// Close closes the SFTP session. +func (c *Client) Close() error { return c.w.Close() } + +// Create creates the named file mode 0666 (before umask), truncating it if +// it already exists. If successful, methods on the returned File can be +// used for I/O; the associated file descriptor has mode O_RDWR. +func (c *Client) Create(path string) (*File, error) { + return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) +} + +const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + +func (c *Client) sendInit() error { + return sendPacket(c.w, sshFxInitPacket{ + Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + }) +} + +// returns the current value of c.nextid and increments it +// callers is expected to hold c.mu +func (c *Client) nextId() uint32 { + v := c.nextid + c.nextid++ + return v +} + +func (c *Client) recvVersion() error { + typ, data, err := recvPacket(c.r) + if err != nil { + return err + } + if typ != ssh_FXP_VERSION { + return &unexpectedPacketErr{ssh_FXP_VERSION, typ} + } + + version, _ := unmarshalUint32(data) + if version != sftpProtocolVersion { + return &unexpectedVersionErr{sftpProtocolVersion, version} + } + + return nil +} + +// Walk returns a new Walker rooted at root. +func (c *Client) Walk(root string) *fs.Walker { + return fs.WalkFS(root, c) +} + +// ReadDir reads the directory named by dirname and returns a list of +// directory entries. +func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { + handle, err := c.opendir(p) + if err != nil { + return nil, err + } + defer c.close(handle) // this has to defer earlier than the lock below + var attrs []os.FileInfo + c.mu.Lock() + defer c.mu.Unlock() + var done = false + for !done { + id := c.nextId() + typ, data, err1 := c.sendRequest(sshFxpReaddirPacket{ + Id: id, + Handle: handle, + }) + if err1 != nil { + err = err1 + done = true + break + } + switch typ { + case ssh_FXP_NAME: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIdErr{id, sid} + } + count, data := unmarshalUint32(data) + for i := uint32(0); i < count; i++ { + var filename string + filename, data = unmarshalString(data) + _, data = unmarshalString(data) // discard longname + var attr *FileStat + attr, data = unmarshalAttrs(data) + if filename == "." || filename == ".." { + continue + } + attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename))) + } + case ssh_FXP_STATUS: + // TODO(dfc) scope warning! + err = eofOrErr(unmarshalStatus(id, data)) + done = true + default: + return nil, unimplementedPacketErr(typ) + } + } + if err == io.EOF { + err = nil + } + return attrs, err +} +func (c *Client) opendir(path string) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpOpendirPacket{ + Id: id, + Path: path, + }) + if err != nil { + return "", err + } + switch typ { + case ssh_FXP_HANDLE: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIdErr{id, sid} + } + handle, _ := unmarshalString(data) + return handle, nil + case ssh_FXP_STATUS: + return "", unmarshalStatus(id, data) + default: + return "", unimplementedPacketErr(typ) + } +} + +func (c *Client) Lstat(p string) (os.FileInfo, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpLstatPacket{ + Id: id, + Path: p, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIdErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), nil + case ssh_FXP_STATUS: + return nil, unmarshalStatus(id, data) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// ReadLink reads the target of a symbolic link. +func (c *Client) ReadLink(p string) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpReadlinkPacket{ + Id: id, + Path: p, + }) + if err != nil { + return "", err + } + switch typ { + case ssh_FXP_NAME: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIdErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore dummy attributes + return filename, nil + case ssh_FXP_STATUS: + return "", unmarshalStatus(id, data) + default: + return "", unimplementedPacketErr(typ) + } +} + +// setstat is a convience wrapper to allow for changing of various parts of the file descriptor. +func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpSetstatPacket{ + Id: id, + Path: path, + Flags: flags, + Attrs: attrs, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Chtimes changes the access and modification times of the named file. +func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error { + type times struct { + Atime uint32 + Mtime uint32 + } + attrs := times{uint32(atime.Unix()), uint32(mtime.Unix())} + return c.setstat(path, ssh_FILEXFER_ATTR_ACMODTIME, attrs) +} + +// Chown changes the user and group owners of the named file. +func (c *Client) Chown(path string, uid, gid int) error { + type owner struct { + Uid uint32 + Gid uint32 + } + attrs := owner{uint32(uid), uint32(gid)} + return c.setstat(path, ssh_FILEXFER_ATTR_UIDGID, attrs) +} + +// Chmod changes the permissions of the named file. +func (c *Client) Chmod(path string, mode os.FileMode) error { + return c.setstat(path, ssh_FILEXFER_ATTR_PERMISSIONS, uint32(mode)) +} + +// Truncate sets the size of the named file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +func (c *Client) Truncate(path string, size int64) error { + return c.setstat(path, ssh_FILEXFER_ATTR_SIZE, uint64(size)) +} + +// Open opens the named file for reading. If successful, methods on the +// returned file can be used for reading; the associated file descriptor +// has mode O_RDONLY. +func (c *Client) Open(path string) (*File, error) { + return c.open(path, flags(os.O_RDONLY)) +} + +// OpenFile is the generalized open call; most users will use Open or +// Create instead. It opens the named file with specified flag (O_RDONLY +// etc.). If successful, methods on the returned File can be used for I/O. +func (c *Client) OpenFile(path string, f int) (*File, error) { + return c.open(path, flags(f)) +} + +func (c *Client) open(path string, pflags uint32) (*File, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpOpenPacket{ + Id: id, + Path: path, + Pflags: pflags, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_HANDLE: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIdErr{id, sid} + } + handle, _ := unmarshalString(data) + return &File{c: c, path: path, handle: handle}, nil + case ssh_FXP_STATUS: + return nil, unmarshalStatus(id, data) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// readAt reads len(buf) bytes from the remote file indicated by handle starting +// from offset. +func (c *Client) readAt(handle string, offset uint64, buf []byte) (uint32, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpReadPacket{ + Id: id, + Handle: handle, + Offset: offset, + Len: uint32(len(buf)), + }) + if err != nil { + return 0, err + } + switch typ { + case ssh_FXP_DATA: + sid, data := unmarshalUint32(data) + if sid != id { + return 0, &unexpectedIdErr{id, sid} + } + l, data := unmarshalUint32(data) + n := copy(buf, data[:l]) + return uint32(n), nil + case ssh_FXP_STATUS: + return 0, eofOrErr(unmarshalStatus(id, data)) + default: + return 0, unimplementedPacketErr(typ) + } +} + +// close closes a handle handle previously returned in the response +// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid +// immediately after this request has been sent. +func (c *Client) close(handle string) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpClosePacket{ + Id: id, + Handle: handle, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) fstat(handle string) (*FileStat, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpFstatPacket{ + Id: id, + Handle: handle, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIdErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return attr, nil + case ssh_FXP_STATUS: + return nil, unmarshalStatus(id, data) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// Join joins any number of path elements into a single path, adding a +// separating slash if necessary. The result is Cleaned; in particular, all +// empty strings are ignored. +func (c *Client) Join(elem ...string) string { return path.Join(elem...) } + +// Remove removes the specified file or directory. An error will be returned if no +// file or directory with the specified path exists, or if the specified directory +// is not empty. +func (c *Client) Remove(path string) error { + err := c.removeFile(path) + if status, ok := err.(*StatusError); ok && status.Code == ssh_FX_FAILURE { + err = c.removeDirectory(path) + } + return err +} + +func (c *Client) removeFile(path string) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpRemovePacket{ + Id: id, + Filename: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) removeDirectory(path string) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpRmdirPacket{ + Id: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Rename renames a file. +func (c *Client) Rename(oldname, newname string) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpRenamePacket{ + Id: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) sendRequest(p encoding.BinaryMarshaler) (byte, []byte, error) { + if err := sendPacket(c.w, p); err != nil { + return 0, nil, err + } + return recvPacket(c.r) +} + +// writeAt writes len(buf) bytes from the remote file indicated by handle starting +// from offset. +func (c *Client) writeAt(handle string, offset uint64, buf []byte) (uint32, error) { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpWritePacket{ + Id: id, + Handle: handle, + Offset: offset, + Length: uint32(len(buf)), + Data: buf, + }) + if err != nil { + return 0, err + } + switch typ { + case ssh_FXP_STATUS: + if err := okOrErr(unmarshalStatus(id, data)); err != nil { + return 0, err + } + return uint32(len(buf)), nil + default: + return 0, unimplementedPacketErr(typ) + } +} + +// Creates the specified directory. An error will be returned if a file or +// directory with the specified path already exists, or if the directory's +// parent folder does not exist (the method cannot create complete paths). +func (c *Client) Mkdir(path string) error { + c.mu.Lock() + defer c.mu.Unlock() + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpMkdirPacket{ + Id: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// File represents a remote file. +type File struct { + c *Client + path string + handle string + offset uint64 // current offset within remote file +} + +// Close closes the File, rendering it unusable for I/O. It returns an +// error, if any. +func (f *File) Close() error { + return f.c.close(f.handle) +} + +// Read reads up to len(b) bytes from the File. It returns the number of +// bytes read and an error, if any. EOF is signaled by a zero count with +// err set to io.EOF. +func (f *File) Read(b []byte) (int, error) { + var read int + for len(b) > 0 { + n, err := f.c.readAt(f.handle, f.offset, b[:min(len(b), maxWritePacket)]) + f.offset += uint64(n) + read += int(n) + if err != nil { + return read, err + } + b = b[n:] + } + return read, nil +} + +// Stat returns the FileInfo structure describing file. If there is an +// error. +func (f *File) Stat() (os.FileInfo, error) { + fs, err := f.c.fstat(f.handle) + if err != nil { + return nil, err + } + return fileInfoFromStat(fs, path.Base(f.path)), nil +} + +// clamp writes to less than 32k +const maxWritePacket = 1 << 15 + +// Write writes len(b) bytes to the File. It returns the number of bytes +// written and an error, if any. Write returns a non-nil error when n != +// len(b). +func (f *File) Write(b []byte) (int, error) { + var written int + for len(b) > 0 { + n, err := f.c.writeAt(f.handle, f.offset, b[:min(len(b), maxWritePacket)]) + f.offset += uint64(n) + written += int(n) + if err != nil { + return written, err + } + b = b[n:] + } + return written, nil +} + +// Seek implements io.Seeker by setting the client offset for the next Read or +// Write. It returns the next offset read. Seeking before or after the end of +// the file is undefined. Seeking relative to the end calls Stat. +func (f *File) Seek(offset int64, whence int) (int64, error) { + switch whence { + case os.SEEK_SET: + f.offset = uint64(offset) + case os.SEEK_CUR: + f.offset = uint64(int64(f.offset) + offset) + case os.SEEK_END: + fi, err := f.Stat() + if err != nil { + return int64(f.offset), err + } + f.offset = uint64(fi.Size() + offset) + default: + return int64(f.offset), unimplementedSeekWhence(whence) + } + return int64(f.offset), nil +} + +// Chown changes the uid/gid of the current file. +func (f *File) Chown(uid, gid int) error { + return f.c.Chown(f.path, uid, gid) +} + +// Chmod changes the permissions of the current file. +func (f *File) Chmod(mode os.FileMode) error { + return f.c.Chmod(f.path, mode) +} + +// Truncate sets the size of the current file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +func (f *File) Truncate(size int64) error { + return f.c.Truncate(f.path, size) +} + +func min(a, b int) int { + if a > b { + return b + } + return a +} + +// okOrErr returns nil if Err.Code is SSH_FX_OK, otherwise it returns the error. +func okOrErr(err error) error { + if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_OK { + return nil + } + return err +} + +func eofOrErr(err error) error { + if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_EOF { + return io.EOF + } + return err +} + +func unmarshalStatus(id uint32, data []byte) error { + sid, data := unmarshalUint32(data) + if sid != id { + return &unexpectedIdErr{id, sid} + } + code, data := unmarshalUint32(data) + msg, data := unmarshalString(data) + lang, _ := unmarshalString(data) + return &StatusError{ + Code: code, + msg: msg, + lang: lang, + } +} + +// flags converts the flags passed to OpenFile into ssh flags. +// Unsupported flags are ignored. +func flags(f int) uint32 { + var out uint32 + switch f & os.O_WRONLY { + case os.O_WRONLY: + out |= ssh_FXF_WRITE + case os.O_RDONLY: + out |= ssh_FXF_READ + } + if f&os.O_RDWR == os.O_RDWR { + out |= ssh_FXF_READ | ssh_FXF_WRITE + } + if f&os.O_APPEND == os.O_APPEND { + out |= ssh_FXF_APPEND + } + if f&os.O_CREATE == os.O_CREATE { + out |= ssh_FXF_CREAT + } + if f&os.O_TRUNC == os.O_TRUNC { + out |= ssh_FXF_TRUNC + } + if f&os.O_EXCL == os.O_EXCL { + out |= ssh_FXF_EXCL + } + return out +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/client_integration_test.go b/Godeps/_workspace/src/github.com/pkg/sftp/client_integration_test.go new file mode 100644 index 000000000..828deea35 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/client_integration_test.go @@ -0,0 +1,898 @@ +package sftp + +// sftp integration tests +// enable with -integration + +import ( + "crypto/sha1" + "flag" + "io" + "io/ioutil" + "math/rand" + "os" + "os/exec" + "path" + "path/filepath" + "reflect" + "testing" + "testing/quick" + + "github.com/kr/fs" +) + +const ( + READONLY = true + READWRITE = false + + debuglevel = "ERROR" // set to "DEBUG" for debugging +) + +var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") +var testSftp = flag.String("sftp", "/usr/lib/openssh/sftp-server", "location of the sftp server binary") + +// testClient returns a *Client connected to a localy running sftp-server +// the *exec.Cmd returned must be defer Wait'd. +func testClient(t testing.TB, readonly bool) (*Client, *exec.Cmd) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only + if !readonly { + cmd = exec.Command(*testSftp, "-e", "-l", debuglevel) // log to stderr + } + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + pr, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Skipf("could not start sftp-server process: %v", err) + } + + sftp, err := NewClientPipe(pr, pw) + if err != nil { + t.Fatal(err) + } + + if err := sftp.sendInit(); err != nil { + defer cmd.Wait() + t.Fatal(err) + } + if err := sftp.recvVersion(); err != nil { + defer cmd.Wait() + t.Fatal(err) + } + return sftp, cmd +} + +func TestNewClient(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + + if err := sftp.Close(); err != nil { + t.Fatal(err) + } +} + +func TestClientLstat(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + want, err := os.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + got, err := sftp.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + if !sameFile(want, got) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), want, got) + } +} + +func TestClientLstatMissing(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + os.Remove(f.Name()) + + _, err = sftp.Lstat(f.Name()) + if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_NO_SUCH_FILE { + t.Fatalf("Lstat: want: %v, got %#v", ssh_FX_NO_SUCH_FILE, err) + } +} + +func TestClientMkdir(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + sub := path.Join(dir, "mkdir1") + if err := sftp.Mkdir(sub); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(sub); err != nil { + t.Fatal(err) + } +} + +func TestClientOpen(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + got, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + if err := got.Close(); err != nil { + t.Fatal(err) + } +} + +const seekBytes = 128 * 1024 + +type seek struct { + offset int64 +} + +func (s seek) Generate(r *rand.Rand, _ int) reflect.Value { + s.offset = int64(r.Int31n(seekBytes)) + return reflect.ValueOf(s) +} + +func (s seek) set(t *testing.T, r io.ReadSeeker) { + if _, err := r.Seek(s.offset, os.SEEK_SET); err != nil { + t.Fatalf("error while seeking with %+v: %v", s, err) + } +} + +func (s seek) current(t *testing.T, r io.ReadSeeker) { + const mid = seekBytes / 2 + + skip := s.offset / 2 + if s.offset > mid { + skip = -skip + } + + if _, err := r.Seek(mid, os.SEEK_SET); err != nil { + t.Fatalf("error seeking to midpoint with %+v: %v", s, err) + } + if _, err := r.Seek(skip, os.SEEK_CUR); err != nil { + t.Fatalf("error seeking from %d with %+v: %v", mid, s, err) + } +} + +func (s seek) end(t *testing.T, r io.ReadSeeker) { + if _, err := r.Seek(-s.offset, os.SEEK_END); err != nil { + t.Fatalf("error seeking from end with %+v: %v", s, err) + } +} + +func TestClientSeek(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + fOS, err := ioutil.TempFile("", "seek-test") + if err != nil { + t.Fatal(err) + } + defer fOS.Close() + + fSFTP, err := sftp.Open(fOS.Name()) + if err != nil { + t.Fatal(err) + } + defer fSFTP.Close() + + writeN(t, fOS, seekBytes) + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.set(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.set(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal absolute seeks: %v", err) + } + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.current(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.current(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal seeks from middle: %v", err) + } + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.end(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.end(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal seeks from end: %v", err) + } +} + +func TestClientCreate(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() +} + +func TestClientAppend(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.OpenFile(f.Name(), os.O_RDWR|os.O_APPEND) + if err != nil { + t.Fatal(err) + } + defer f2.Close() +} + +func TestClientCreateFailed(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_PERMISSION_DENIED { + t.Fatalf("Create: want: %v, got %#v", ssh_FX_PERMISSION_DENIED, err) + } + if err == nil { + f2.Close() + } +} + +func TestClientFileStat(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + want, err := os.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + + got, err := f2.Stat() + if err != nil { + t.Fatal(err) + } + + if !sameFile(want, got) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), want, got) + } +} + +func TestClientRemove(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(f.Name()); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(f.Name()); !os.IsNotExist(err) { + t.Fatal(err) + } +} + +func TestClientRemoveDir(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(dir); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(dir); !os.IsNotExist(err) { + t.Fatal(err) + } +} + +func TestClientRemoveFailed(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(f.Name()); err == nil { + t.Fatalf("Remove(%v): want: permission denied, got %v", f.Name(), err) + } + if _, err := os.Lstat(f.Name()); err != nil { + t.Fatal(err) + } +} + +func TestClientRename(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".new" + if err := sftp.Rename(f.Name(), f2); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(f.Name()); !os.IsNotExist(err) { + t.Fatal(err) + } + if _, err := os.Lstat(f2); err != nil { + t.Fatal(err) + } +} + +func TestClientReadLine(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".sym" + if err := os.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if _, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } +} + +func sameFile(want, got os.FileInfo) bool { + return want.Name() == got.Name() && + want.Size() == got.Size() +} + +var clientReadTests = []struct { + n int64 +}{ + {0}, + {1}, + {1000}, + {1024}, + {1025}, + {2048}, + {4096}, + {1 << 12}, + {1 << 13}, + {1 << 14}, + {1 << 15}, + {1 << 16}, + {1 << 17}, + {1 << 18}, + {1 << 19}, + {1 << 20}, +} + +func TestClientRead(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + for _, tt := range clientReadTests { + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + defer f.Close() + hash := writeN(t, f, tt.n) + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + hash2, n := readHash(t, f2) + if hash != hash2 || tt.n != n { + t.Errorf("Read: hash: want: %q, got %q, read: want: %v, got %v", hash, hash2, tt.n, n) + } + } +} + +// readHash reads r until EOF returning the number of bytes read +// and the hash of the contents. +func readHash(t *testing.T, r io.Reader) (string, int64) { + h := sha1.New() + tr := io.TeeReader(r, h) + read, err := io.Copy(ioutil.Discard, tr) + if err != nil { + t.Fatal(err) + } + return string(h.Sum(nil)), read +} + +// writeN writes n bytes of random data to w and returns the +// hash of that data. +func writeN(t *testing.T, w io.Writer, n int64) string { + rand, err := os.Open("/dev/urandom") + if err != nil { + t.Fatal(err) + } + defer rand.Close() + + h := sha1.New() + + mw := io.MultiWriter(w, h) + + written, err := io.CopyN(mw, rand, n) + if err != nil { + t.Fatal(err) + } + if written != n { + t.Fatalf("CopyN(%v): wrote: %v", n, written) + } + return string(h.Sum(nil)) +} + +var clientWriteTests = []struct { + n int + total int64 // cumulative file size +}{ + {0, 0}, + {1, 1}, + {0, 1}, + {999, 1000}, + {24, 1024}, + {1023, 2047}, + {2048, 4095}, + {1 << 12, 8191}, + {1 << 13, 16383}, + {1 << 14, 32767}, + {1 << 15, 65535}, + {1 << 16, 131071}, + {1 << 17, 262143}, + {1 << 18, 524287}, + {1 << 19, 1048575}, + {1 << 20, 2097151}, + {1 << 21, 4194303}, +} + +func TestClientWrite(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for _, tt := range clientWriteTests { + got, err := w.Write(make([]byte, tt.n)) + if err != nil { + t.Fatal(err) + } + if got != tt.n { + t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) + } + fi, err := os.Stat(f) + if err != nil { + t.Fatal(err) + } + if total := fi.Size(); total != tt.total { + t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) + } + } +} + +// taken from github.com/kr/fs/walk_test.go + +type PathTest struct { + path, result string +} + +type Node struct { + name string + entries []*Node // nil if the entry is a file + mark int +} + +var tree = &Node{ + "testdata", + []*Node{ + {"a", nil, 0}, + {"b", []*Node{}, 0}, + {"c", nil, 0}, + { + "d", + []*Node{ + {"x", nil, 0}, + {"y", []*Node{}, 0}, + { + "z", + []*Node{ + {"u", nil, 0}, + {"v", nil, 0}, + }, + 0, + }, + }, + 0, + }, + }, + 0, +} + +func walkTree(n *Node, path string, f func(path string, n *Node)) { + f(path, n) + for _, e := range n.entries { + walkTree(e, filepath.Join(path, e.name), f) + } +} + +func makeTree(t *testing.T) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.entries == nil { + fd, err := os.Create(path) + if err != nil { + t.Errorf("makeTree: %v", err) + return + } + fd.Close() + } else { + os.Mkdir(path, 0770) + } + }) +} + +func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) } + +func checkMarks(t *testing.T, report bool) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.mark != 1 && report { + t.Errorf("node %s mark = %d; expected 1", path, n.mark) + } + n.mark = 0 + }) +} + +// Assumes that each node name is unique. Good enough for a test. +// If clear is true, any incoming error is cleared before return. The errors +// are always accumulated, though. +func mark(path string, info os.FileInfo, err error, errors *[]error, clear bool) error { + if err != nil { + *errors = append(*errors, err) + if clear { + return nil + } + return err + } + name := info.Name() + walkTree(tree, tree.name, func(path string, n *Node) { + if n.name == name { + n.mark++ + } + }) + return nil +} + +func TestClientWalk(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + makeTree(t) + errors := make([]error, 0, 10) + clear := true + markFn := func(walker *fs.Walker) (err error) { + for walker.Step() { + err = mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) + if err != nil { + break + } + } + return err + } + // Expect no errors. + err := markFn(sftp.Walk(tree.name)) + if err != nil { + t.Fatalf("no error expected, found: %s", err) + } + if len(errors) != 0 { + t.Fatalf("unexpected errors: %s", errors) + } + checkMarks(t, true) + errors = errors[0:0] + + // Test permission errors. Only possible if we're not root + // and only on some file systems (AFS, FAT). To avoid errors during + // all.bash on those file systems, skip during go test -short. + if os.Getuid() > 0 && !testing.Short() { + // introduce 2 errors: chmod top-level directories to 0 + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0) + + // 3) capture errors, expect two. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + err := markFn(sftp.Walk(tree.name)) + if err != nil { + t.Fatalf("expected no error return from Walk, got %s", err) + } + if len(errors) != 2 { + t.Errorf("expected 2 errors, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, true) + errors = errors[0:0] + + // 4) capture errors, stop after first error. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + clear = false // error will stop processing + err = markFn(sftp.Walk(tree.name)) + if err == nil { + t.Fatalf("expected error return from Walk") + } + if len(errors) != 1 { + t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, false) + errors = errors[0:0] + + // restore permissions + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + } + + // cleanup + if err := os.RemoveAll(tree.name); err != nil { + t.Errorf("removeTree: %v", err) + } +} + +func benchmarkRead(b *testing.B, bufsize int) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, READONLY) + defer cmd.Wait() + defer sftp.Close() + + buf := make([]byte, bufsize) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + offset := 0 + + f2, err := sftp.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + n, err := io.ReadFull(f2, buf) + offset += n + if err == io.ErrUnexpectedEOF && offset != size { + b.Fatalf("read too few bytes! want: %d, got: %d", size, n) + } + + if err != nil { + b.Fatal(err) + } + + offset += n + } + } +} + +func BenchmarkRead1k(b *testing.B) { + benchmarkRead(b, 1*1024) +} + +func BenchmarkRead16k(b *testing.B) { + benchmarkRead(b, 16*1024) +} + +func BenchmarkRead32k(b *testing.B) { + benchmarkRead(b, 32*1024) +} + +func BenchmarkRead128k(b *testing.B) { + benchmarkRead(b, 128*1024) +} + +func BenchmarkRead512k(b *testing.B) { + benchmarkRead(b, 512*1024) +} + +func BenchmarkRead1MiB(b *testing.B) { + benchmarkRead(b, 1024*1024) +} + +func BenchmarkRead4MiB(b *testing.B) { + benchmarkRead(b, 4*1024*1024) +} + +func benchmarkWrite(b *testing.B, bufsize int) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false) + defer cmd.Wait() + defer sftp.Close() + + data := make([]byte, size) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + offset := 0 + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + n, err := f2.Write(data[offset:min(len(data), offset+bufsize)]) + if err != nil { + b.Fatal(err) + } + + if offset+n < size && n != bufsize { + b.Fatalf("wrote too few bytes! want: %d, got: %d", size, n) + } + + offset += n + } + + f2.Close() + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: want %d, got %d", size, fi.Size()) + } + + os.Remove(f.Name()) + } +} + +func BenchmarkWrite1k(b *testing.B) { + benchmarkWrite(b, 1*1024) +} + +func BenchmarkWrite16k(b *testing.B) { + benchmarkWrite(b, 16*1024) +} + +func BenchmarkWrite32k(b *testing.B) { + benchmarkWrite(b, 32*1024) +} + +func BenchmarkWrite128k(b *testing.B) { + benchmarkWrite(b, 128*1024) +} + +func BenchmarkWrite512k(b *testing.B) { + benchmarkWrite(b, 512*1024) +} + +func BenchmarkWrite1MiB(b *testing.B) { + benchmarkWrite(b, 1024*1024) +} + +func BenchmarkWrite4MiB(b *testing.B) { + benchmarkWrite(b, 4*1024*1024) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/client_test.go b/Godeps/_workspace/src/github.com/pkg/sftp/client_test.go new file mode 100644 index 000000000..9ade6d1af --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/client_test.go @@ -0,0 +1,75 @@ +package sftp + +import ( + "io" + "os" + "testing" + + "github.com/kr/fs" +) + +// assert that *Client implements fs.FileSystem +var _ fs.FileSystem = new(Client) + +// assert that *File implements io.ReadWriteCloser +var _ io.ReadWriteCloser = new(File) + +var ok = &StatusError{Code: ssh_FX_OK} +var eof = &StatusError{Code: ssh_FX_EOF} +var fail = &StatusError{Code: ssh_FX_FAILURE} + +var eofOrErrTests = []struct { + err, want error +}{ + {nil, nil}, + {eof, io.EOF}, + {ok, ok}, + {io.EOF, io.EOF}, +} + +func TestEofOrErr(t *testing.T) { + for _, tt := range eofOrErrTests { + got := eofOrErr(tt.err) + if got != tt.want { + t.Errorf("eofOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got) + } + } +} + +var okOrErrTests = []struct { + err, want error +}{ + {nil, nil}, + {eof, eof}, + {ok, nil}, + {io.EOF, io.EOF}, +} + +func TestOkOrErr(t *testing.T) { + for _, tt := range okOrErrTests { + got := okOrErr(tt.err) + if got != tt.want { + t.Errorf("okOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got) + } + } +} + +var flagsTests = []struct { + flags int + want uint32 +}{ + {os.O_RDONLY, ssh_FXF_READ}, + {os.O_WRONLY, ssh_FXF_WRITE}, + {os.O_RDWR, ssh_FXF_READ | ssh_FXF_WRITE}, + {os.O_RDWR | os.O_CREATE | os.O_TRUNC, ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_CREAT | ssh_FXF_TRUNC}, + {os.O_WRONLY | os.O_APPEND, ssh_FXF_WRITE | ssh_FXF_APPEND}, +} + +func TestFlags(t *testing.T) { + for i, tt := range flagsTests { + got := flags(tt.flags) + if got != tt.want { + t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) + } + } +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/debug.go b/Godeps/_workspace/src/github.com/pkg/sftp/debug.go new file mode 100644 index 000000000..3e264abe3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/debug.go @@ -0,0 +1,9 @@ +// +build debug + +package sftp + +import "log" + +func debug(fmt string, args ...interface{}) { + log.Printf(fmt, args...) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/example_test.go b/Godeps/_workspace/src/github.com/pkg/sftp/example_test.go new file mode 100644 index 000000000..3f73726de --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/example_test.go @@ -0,0 +1,91 @@ +package sftp_test + +import ( + "fmt" + "log" + "os" + "os/exec" + + "golang.org/x/crypto/ssh" + + "github.com/pkg/sftp" +) + +func Example(conn *ssh.Client) { + // open an SFTP session over an existing ssh connection. + sftp, err := sftp.NewClient(conn) + if err != nil { + log.Fatal(err) + } + defer sftp.Close() + + // walk a directory + w := sftp.Walk("/home/user") + for w.Step() { + if w.Err() != nil { + continue + } + log.Println(w.Path()) + } + + // leave your mark + f, err := sftp.Create("hello.txt") + if err != nil { + log.Fatal(err) + } + if _, err := f.Write([]byte("Hello world!")); err != nil { + log.Fatal(err) + } + + // check it's there + fi, err := sftp.Lstat("hello.txt") + if err != nil { + log.Fatal(err) + } + log.Println(fi) +} + +func ExampleNewClientPipe() { + // Connect to a remote host and request the sftp subsystem via the 'ssh' + // command. This assumes that passwordless login is correctly configured. + cmd := exec.Command("ssh", "example.com", "-s", "sftp") + + // send errors from ssh to stderr + cmd.Stderr = os.Stderr + + // get stdin and stdout + wr, err := cmd.StdinPipe() + if err != nil { + log.Fatal(err) + } + rd, err := cmd.StdoutPipe() + if err != nil { + log.Fatal(err) + } + + // start the process + if err := cmd.Start(); err != nil { + log.Fatal(err) + } + defer cmd.Wait() + + // open the SFTP session + client, err := sftp.NewClientPipe(rd, wr) + if err != nil { + log.Fatal(err) + } + + // read a directory + list, err := client.ReadDir("/") + if err != nil { + log.Fatal(err) + } + + // print contents + for _, item := range list { + fmt.Println(item.Name()) + } + + // close the connection + client.Close() +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go b/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go new file mode 100644 index 000000000..5d63f18c9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go @@ -0,0 +1,76 @@ +// buffered-read-benchmark benchmarks the peformance of reading +// from /dev/zero on the server to a []byte on the client via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + r, err := c.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer r.Close() + + const size = 1e9 + + log.Printf("reading %v bytes", size) + t1 := time.Now() + n, err := io.ReadFull(r, make([]byte, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("read %v bytes in %s", size, time.Since(t1)) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go b/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go new file mode 100644 index 000000000..c0ea133a8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go @@ -0,0 +1,82 @@ +// buffered-write-benchmark benchmarks the peformance of writing +// a single large []byte on the client to /dev/null on the server via io.Copy. +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + w, err := c.OpenFile("/dev/null", syscall.O_WRONLY) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + f, err := os.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer f.Close() + + const size = 1e9 + + log.Printf("writing %v bytes", size) + t1 := time.Now() + n, err := w.Write(make([]byte, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("wrote %v bytes in %s", size, time.Since(t1)) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/examples/gsftp/main.go b/Godeps/_workspace/src/github.com/pkg/sftp/examples/gsftp/main.go new file mode 100644 index 000000000..a1a8824b5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/examples/gsftp/main.go @@ -0,0 +1,147 @@ +// gsftp implements a simple sftp client. +// +// gsftp understands the following commands: +// +// List a directory (and its subdirectories) +// gsftp ls DIR +// +// Fetch a remote file +// gsftp fetch FILE +// +// Put the contents of stdin to a remote file +// cat LOCALFILE | gsftp put REMOTEFILE +// +// Print the details of a remote file +// gsftp stat FILE +// +// Remove a remote file +// gsftp rm FILE +// +// Rename a file +// gsftp mv OLD NEW +// +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") +) + +func init() { + flag.Parse() + if len(flag.Args()) < 1 { + log.Fatal("subcommand required") + } +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + client, err := sftp.NewClient(conn) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer client.Close() + switch cmd := flag.Args()[0]; cmd { + case "ls": + if len(flag.Args()) < 2 { + log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) + } + walker := client.Walk(flag.Args()[1]) + for walker.Step() { + if err := walker.Err(); err != nil { + log.Println(err) + continue + } + fmt.Println(walker.Path()) + } + case "fetch": + if len(flag.Args()) < 2 { + log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) + } + f, err := client.Open(flag.Args()[1]) + if err != nil { + log.Fatal(err) + } + defer f.Close() + if _, err := io.Copy(os.Stdout, f); err != nil { + log.Fatal(err) + } + case "put": + if len(flag.Args()) < 2 { + log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) + } + f, err := client.Create(flag.Args()[1]) + if err != nil { + log.Fatal(err) + } + defer f.Close() + if _, err := io.Copy(f, os.Stdin); err != nil { + log.Fatal(err) + } + case "stat": + if len(flag.Args()) < 2 { + log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) + } + f, err := client.Open(flag.Args()[1]) + if err != nil { + log.Fatal(err) + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + log.Fatalf("unable to stat file: %v", err) + } + fmt.Printf("%s %d %v\n", fi.Name(), fi.Size(), fi.Mode()) + case "rm": + if len(flag.Args()) < 2 { + log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) + } + if err := client.Remove(flag.Args()[1]); err != nil { + log.Fatalf("unable to remove file: %v", err) + } + case "mv": + if len(flag.Args()) < 3 { + log.Fatalf("%s %s: old and new name required", cmd, os.Args[0]) + } + if err := client.Rename(flag.Args()[1], flag.Args()[2]); err != nil { + log.Fatalf("unable to rename file: %v", err) + } + default: + log.Fatalf("unknown subcommand: %v", cmd) + } +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go b/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go new file mode 100644 index 000000000..7ebdbd46e --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go @@ -0,0 +1,83 @@ +// streaming-read-benchmark benchmarks the peformance of reading +// from /dev/zero on the server to /dev/null on the client via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + r, err := c.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer r.Close() + + w, err := os.OpenFile("/dev/null", syscall.O_WRONLY, 0600) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + const size int64 = 1e9 + + log.Printf("reading %v bytes", size) + t1 := time.Now() + n, err := io.Copy(w, io.LimitReader(r, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("read %v bytes in %s", size, time.Since(t1)) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go b/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go new file mode 100644 index 000000000..63b27b55f --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go @@ -0,0 +1,83 @@ +// streaming-write-benchmark benchmarks the peformance of writing +// from /dev/zero on the client to /dev/null on the server via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + w, err := c.OpenFile("/dev/null", syscall.O_WRONLY) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + f, err := os.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer f.Close() + + const size int64 = 1e9 + + log.Printf("writing %v bytes", size) + t1 := time.Now() + n, err := io.Copy(w, io.LimitReader(f, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("wrote %v bytes in %s", size, time.Since(t1)) +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/packet.go b/Godeps/_workspace/src/github.com/pkg/sftp/packet.go new file mode 100644 index 000000000..f2c2b671f --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/packet.go @@ -0,0 +1,331 @@ +package sftp + +import ( + "encoding" + "fmt" + "io" + "reflect" +) + +func marshalUint32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +func marshalUint64(b []byte, v uint64) []byte { + return marshalUint32(marshalUint32(b, uint32(v>>32)), uint32(v)) +} + +func marshalString(b []byte, v string) []byte { + return append(marshalUint32(b, uint32(len(v))), v...) +} + +func marshal(b []byte, v interface{}) []byte { + switch v := v.(type) { + case uint8: + return append(b, v) + case uint32: + return marshalUint32(b, v) + case uint64: + return marshalUint64(b, v) + case string: + return marshalString(b, v) + default: + switch d := reflect.ValueOf(v); d.Kind() { + case reflect.Struct: + for i, n := 0, d.NumField(); i < n; i++ { + b = append(marshal(b, d.Field(i).Interface())) + } + return b + case reflect.Slice: + for i, n := 0, d.Len(); i < n; i++ { + b = append(marshal(b, d.Index(i).Interface())) + } + return b + default: + panic(fmt.Sprintf("marshal(%#v): cannot handle type %T", v, v)) + } + } +} + +func unmarshalUint32(b []byte) (uint32, []byte) { + v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + return v, b[4:] +} + +func unmarshalUint64(b []byte) (uint64, []byte) { + h, b := unmarshalUint32(b) + l, b := unmarshalUint32(b) + return uint64(h)<<32 | uint64(l), b +} + +func unmarshalString(b []byte) (string, []byte) { + n, b := unmarshalUint32(b) + return string(b[:n]), b[n:] +} + +// sendPacket marshals p according to RFC 4234. + +func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { + bb, err := m.MarshalBinary() + if err != nil { + return fmt.Errorf("marshal2(%#v): binary marshaller failed", err) + } + l := uint32(len(bb)) + hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} + debug("send packet %T, len: %v", m, l) + _, err = w.Write(hdr) + if err != nil { + return err + } + _, err = w.Write(bb) + return err +} + +func recvPacket(r io.Reader) (uint8, []byte, error) { + var b = []byte{0, 0, 0, 0} + if _, err := io.ReadFull(r, b); err != nil { + return 0, nil, err + } + l, _ := unmarshalUint32(b) + b = make([]byte, l) + if _, err := io.ReadFull(r, b); err != nil { + return 0, nil, err + } + return b[0], b[1:], nil +} + +// Here starts the definition of packets along with their MarshalBinary +// implementations. +// Manually writing the marshalling logic wins us a lot of time and +// allocation. + +type sshFxInitPacket struct { + Version uint32 + Extensions []struct { + Name, Data string + } +} + +func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 // byte + uint32 + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_INIT) + b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + return b, nil +} + +func marshalIdString(packetType byte, id uint32, str string) ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(str) + + b := make([]byte, 0, l) + b = append(b, packetType) + b = marshalUint32(b, id) + b = marshalString(b, str) + return b, nil +} + +type sshFxpReaddirPacket struct { + Id uint32 + Handle string +} + +func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_READDIR, p.Id, p.Handle) +} + +type sshFxpOpendirPacket struct { + Id uint32 + Path string +} + +func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_OPENDIR, p.Id, p.Path) +} + +type sshFxpLstatPacket struct { + Id uint32 + Path string +} + +func (p sshFxpLstatPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_LSTAT, p.Id, p.Path) +} + +type sshFxpFstatPacket struct { + Id uint32 + Handle string +} + +func (p sshFxpFstatPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_FSTAT, p.Id, p.Handle) +} + +type sshFxpClosePacket struct { + Id uint32 + Handle string +} + +func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_CLOSE, p.Id, p.Handle) +} + +type sshFxpRemovePacket struct { + Id uint32 + Filename string +} + +func (p sshFxpRemovePacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_REMOVE, p.Id, p.Filename) +} + +type sshFxpRmdirPacket struct { + Id uint32 + Path string +} + +func (p sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_RMDIR, p.Id, p.Path) +} + +type sshFxpReadlinkPacket struct { + Id uint32 + Path string +} + +func (p sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_READLINK, p.Id, p.Path) +} + +type sshFxpOpenPacket struct { + Id uint32 + Path string + Pflags uint32 + Flags uint32 // ignored +} + +func (p sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + + 4 + len(p.Path) + + 4 + 4 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_OPEN) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Pflags) + b = marshalUint32(b, p.Flags) + return b, nil +} + +type sshFxpReadPacket struct { + Id uint32 + Handle string + Offset uint64 + Len uint32 +} + +func (p sshFxpReadPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Handle) + + 8 + 4 // uint64 + uint32 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_READ) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Len) + return b, nil +} + +type sshFxpRenamePacket struct { + Id uint32 + Oldpath string + Newpath string +} + +func (p sshFxpRenamePacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_RENAME) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + return b, nil +} + +type sshFxpWritePacket struct { + Id uint32 + Handle string + Offset uint64 + Length uint32 + Data []byte +} + +func (s sshFxpWritePacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(s.Handle) + + 8 + 4 + // uint64 + uint32 + len(s.Data) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_WRITE) + b = marshalUint32(b, s.Id) + b = marshalString(b, s.Handle) + b = marshalUint64(b, s.Offset) + b = marshalUint32(b, s.Length) + b = append(b, s.Data...) + return b, nil +} + +type sshFxpMkdirPacket struct { + Id uint32 + Path string + Flags uint32 // ignored +} + +func (p sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Path) + + 4 // uint32 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_MKDIR) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + return b, nil +} + +type sshFxpSetstatPacket struct { + Id uint32 + Path string + Flags uint32 + Attrs interface{} +} + +func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Path) + + 4 // uint32 + uint64 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_SETSTAT) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + b = marshal(b, p.Attrs) + return b, nil +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/packet_test.go b/Godeps/_workspace/src/github.com/pkg/sftp/packet_test.go new file mode 100644 index 000000000..80a1ebf02 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/packet_test.go @@ -0,0 +1,261 @@ +package sftp + +import ( + "bytes" + "encoding" + "os" + "testing" +) + +var marshalUint32Tests = []struct { + v uint32 + want []byte +}{ + {1, []byte{0, 0, 0, 1}}, + {256, []byte{0, 0, 1, 0}}, + {^uint32(0), []byte{255, 255, 255, 255}}, +} + +func TestMarshalUint32(t *testing.T) { + for _, tt := range marshalUint32Tests { + got := marshalUint32(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got) + } + } +} + +var marshalUint64Tests = []struct { + v uint64 + want []byte +}{ + {1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}}, + {256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}}, + {^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, +} + +func TestMarshalUint64(t *testing.T) { + for _, tt := range marshalUint64Tests { + got := marshalUint64(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var marshalStringTests = []struct { + v string + want []byte +}{ + {"", []byte{0, 0, 0, 0}}, + {"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}}, +} + +func TestMarshalString(t *testing.T) { + for _, tt := range marshalStringTests { + got := marshalString(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var marshalTests = []struct { + v interface{} + want []byte +}{ + {uint8(1), []byte{1}}, + {byte(1), []byte{1}}, + {uint32(1), []byte{0, 0, 0, 1}}, + {uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}}, + {"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}}, + {[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}}, +} + +func TestMarshal(t *testing.T) { + for _, tt := range marshalTests { + got := marshal(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var unmarshalUint32Tests = []struct { + b []byte + want uint32 + rest []byte +}{ + {[]byte{0, 0, 0, 0}, 0, nil}, + {[]byte{0, 0, 1, 0}, 256, nil}, + {[]byte{255, 0, 0, 255}, 4278190335, nil}, +} + +func TestUnmarshalUint32(t *testing.T) { + for _, tt := range unmarshalUint32Tests { + got, rest := unmarshalUint32(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var unmarshalUint64Tests = []struct { + b []byte + want uint64 + rest []byte +}{ + {[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil}, + {[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil}, + {[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil}, +} + +func TestUnmarshalUint64(t *testing.T) { + for _, tt := range unmarshalUint64Tests { + got, rest := unmarshalUint64(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var unmarshalStringTests = []struct { + b []byte + want string + rest []byte +}{ + {marshalString(nil, ""), "", nil}, + {marshalString(nil, "blah"), "blah", nil}, +} + +func TestUnmarshalString(t *testing.T) { + for _, tt := range unmarshalStringTests { + got, rest := unmarshalString(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var sendPacketTests = []struct { + p encoding.BinaryMarshaler + want []byte +}{ + {sshFxInitPacket{ + Version: 3, + Extensions: []struct{ Name, Data string }{ + {"posix-rename@openssh.com", "1"}, + }, + }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, + + {sshFxpOpenPacket{ + Id: 1, + Path: "/foo", + Pflags: flags(os.O_RDONLY), + }, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, + + {sshFxpWritePacket{ + Id: 124, + Handle: "foo", + Offset: 13, + Length: uint32(len([]byte("bar"))), + Data: []byte("bar"), + }, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}}, + + {sshFxpSetstatPacket{ + Id: 31, + Path: "/bar", + Flags: flags(os.O_WRONLY), + Attrs: struct { + Uid uint32 + Gid uint32 + }{1000, 100}, + }, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}}, +} + +func TestSendPacket(t *testing.T) { + for _, tt := range sendPacketTests { + var w bytes.Buffer + sendPacket(&w, tt.p) + if got := w.Bytes(); !bytes.Equal(tt.want, got) { + t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got) + } + } +} + +func sp(p encoding.BinaryMarshaler) []byte { + var w bytes.Buffer + sendPacket(&w, p) + return w.Bytes() +} + +var recvPacketTests = []struct { + b []byte + want uint8 + rest []byte +}{ + {sp(sshFxInitPacket{ + Version: 3, + Extensions: []struct{ Name, Data string }{ + {"posix-rename@openssh.com", "1"}, + }, + }), ssh_FXP_INIT, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, +} + +func TestRecvPacket(t *testing.T) { + for _, tt := range recvPacketTests { + r := bytes.NewReader(tt.b) + got, rest, _ := recvPacket(r) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +func BenchmarkMarshalInit(b *testing.B) { + for i := 0; i < b.N; i++ { + sp(sshFxInitPacket{ + Version: 3, + Extensions: []struct{ Name, Data string }{ + {"posix-rename@openssh.com", "1"}, + }, + }) + } +} + +func BenchmarkMarshalOpen(b *testing.B) { + for i := 0; i < b.N; i++ { + sp(sshFxpOpenPacket{ + Id: 1, + Path: "/home/test/some/random/path", + Pflags: flags(os.O_RDONLY), + }) + } +} + +func BenchmarkMarshalWriteWorstCase(b *testing.B) { + data := make([]byte, 32*1024) + for i := 0; i < b.N; i++ { + sp(sshFxpWritePacket{ + Id: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) + } +} + +func BenchmarkMarshalWrite1k(b *testing.B) { + data := make([]byte, 1024) + for i := 0; i < b.N; i++ { + sp(sshFxpWritePacket{ + Id: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) + } +} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/release.go b/Godeps/_workspace/src/github.com/pkg/sftp/release.go new file mode 100644 index 000000000..b695528fd --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/release.go @@ -0,0 +1,5 @@ +// +build !debug + +package sftp + +func debug(fmt string, args ...interface{}) {} diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/sftp.go b/Godeps/_workspace/src/github.com/pkg/sftp/sftp.go new file mode 100644 index 000000000..934684c23 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/sftp.go @@ -0,0 +1,187 @@ +// Package sftp implements the SSH File Transfer Protocol as described in +// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt +package sftp + +import ( + "fmt" +) + +const ( + ssh_FXP_INIT = 1 + ssh_FXP_VERSION = 2 + ssh_FXP_OPEN = 3 + ssh_FXP_CLOSE = 4 + ssh_FXP_READ = 5 + ssh_FXP_WRITE = 6 + ssh_FXP_LSTAT = 7 + ssh_FXP_FSTAT = 8 + ssh_FXP_SETSTAT = 9 + ssh_FXP_FSETSTAT = 10 + ssh_FXP_OPENDIR = 11 + ssh_FXP_READDIR = 12 + ssh_FXP_REMOVE = 13 + ssh_FXP_MKDIR = 14 + ssh_FXP_RMDIR = 15 + ssh_FXP_REALPATH = 16 + ssh_FXP_STAT = 17 + ssh_FXP_RENAME = 18 + ssh_FXP_READLINK = 19 + ssh_FXP_SYMLINK = 20 + ssh_FXP_STATUS = 101 + ssh_FXP_HANDLE = 102 + ssh_FXP_DATA = 103 + ssh_FXP_NAME = 104 + ssh_FXP_ATTRS = 105 + ssh_FXP_EXTENDED = 200 + ssh_FXP_EXTENDED_REPLY = 201 +) + +const ( + ssh_FX_OK = 0 + ssh_FX_EOF = 1 + ssh_FX_NO_SUCH_FILE = 2 + ssh_FX_PERMISSION_DENIED = 3 + ssh_FX_FAILURE = 4 + ssh_FX_BAD_MESSAGE = 5 + ssh_FX_NO_CONNECTION = 6 + ssh_FX_CONNECTION_LOST = 7 + ssh_FX_OP_UNSUPPORTED = 8 +) + +const ( + ssh_FXF_READ = 0x00000001 + ssh_FXF_WRITE = 0x00000002 + ssh_FXF_APPEND = 0x00000004 + ssh_FXF_CREAT = 0x00000008 + ssh_FXF_TRUNC = 0x00000010 + ssh_FXF_EXCL = 0x00000020 +) + +type fxp uint8 + +func (f fxp) String() string { + switch f { + case ssh_FXP_INIT: + return "SSH_FXP_INIT" + case ssh_FXP_VERSION: + return "SSH_FXP_VERSION" + case ssh_FXP_OPEN: + return "SSH_FXP_OPEN" + case ssh_FXP_CLOSE: + return "SSH_FXP_CLOSE" + case ssh_FXP_READ: + return "SSH_FXP_READ" + case ssh_FXP_WRITE: + return "SSH_FXP_WRITE" + case ssh_FXP_LSTAT: + return "SSH_FXP_LSTAT" + case ssh_FXP_FSTAT: + return "SSH_FXP_FSTAT" + case ssh_FXP_SETSTAT: + return "SSH_FXP_SETSTAT" + case ssh_FXP_FSETSTAT: + return "SSH_FXP_FSETSTAT" + case ssh_FXP_OPENDIR: + return "SSH_FXP_OPENDIR" + case ssh_FXP_READDIR: + return "SSH_FXP_READDIR" + case ssh_FXP_REMOVE: + return "SSH_FXP_REMOVE" + case ssh_FXP_MKDIR: + return "SSH_FXP_MKDIR" + case ssh_FXP_RMDIR: + return "SSH_FXP_RMDIR" + case ssh_FXP_REALPATH: + return "SSH_FXP_REALPATH" + case ssh_FXP_STAT: + return "SSH_FXP_STAT" + case ssh_FXP_RENAME: + return "SSH_FXP_RENAME" + case ssh_FXP_READLINK: + return "SSH_FXP_READLINK" + case ssh_FXP_SYMLINK: + return "SSH_FXP_SYMLINK" + case ssh_FXP_STATUS: + return "SSH_FXP_STATUS" + case ssh_FXP_HANDLE: + return "SSH_FXP_HANDLE" + case ssh_FXP_DATA: + return "SSH_FXP_DATA" + case ssh_FXP_NAME: + return "SSH_FXP_NAME" + case ssh_FXP_ATTRS: + return "SSH_FXP_ATTRS" + case ssh_FXP_EXTENDED: + return "SSH_FXP_EXTENDED" + case ssh_FXP_EXTENDED_REPLY: + return "SSH_FXP_EXTENDED_REPLY" + default: + return "unknown" + } +} + +type fx uint8 + +func (f fx) String() string { + switch f { + case ssh_FX_OK: + return "SSH_FX_OK" + case ssh_FX_EOF: + return "SSH_FX_EOF" + case ssh_FX_NO_SUCH_FILE: + return "SSH_FX_NO_SUCH_FILE" + case ssh_FX_PERMISSION_DENIED: + return "SSH_FX_PERMISSION_DENIED" + case ssh_FX_FAILURE: + return "SSH_FX_FAILURE" + case ssh_FX_BAD_MESSAGE: + return "SSH_FX_BAD_MESSAGE" + case ssh_FX_NO_CONNECTION: + return "SSH_FX_NO_CONNECTION" + case ssh_FX_CONNECTION_LOST: + return "SSH_FX_CONNECTION_LOST" + case ssh_FX_OP_UNSUPPORTED: + return "SSH_FX_OP_UNSUPPORTED" + default: + return "unknown" + } +} + +type unexpectedPacketErr struct { + want, got uint8 +} + +func (u *unexpectedPacketErr) Error() string { + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) +} + +func unimplementedPacketErr(u uint8) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +} + +type unexpectedIdErr struct{ want, got uint32 } + +func (u *unexpectedIdErr) Error() string { + return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got) +} + +func unimplementedSeekWhence(whence int) error { + return fmt.Errorf("sftp: unimplemented seek whence %v", whence) +} + +func unexpectedCount(want, got uint32) error { + return fmt.Errorf("sftp: unexpected count: want %v, got %v", want, got) +} + +type unexpectedVersionErr struct{ want, got uint32 } + +func (u *unexpectedVersionErr) Error() string { + return fmt.Sprintf("sftp: unexpected server version: want %v, got %v", u.want, u.got) +} + +type StatusError struct { + Code uint32 + msg, lang string +} + +func (s *StatusError) Error() string { return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code)) } diff --git a/Godeps/_workspace/src/github.com/pkg/sftp/wercker.yml b/Godeps/_workspace/src/github.com/pkg/sftp/wercker.yml new file mode 100644 index 000000000..41d2c5280 --- /dev/null +++ b/Godeps/_workspace/src/github.com/pkg/sftp/wercker.yml @@ -0,0 +1 @@ +box: wercker/golang diff --git a/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2.go b/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2.go new file mode 100644 index 000000000..593f65300 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC +2898 / PKCS #5 v2.0. + +A key derivation function is useful when encrypting data based on a password +or any other not-fully-random data. It uses a pseudorandom function to derive +a secure encryption key based on the password. + +While v2.0 of the standard defines only one pseudorandom function to use, +HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved +Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To +choose, you can pass the `New` functions from the different SHA packages to +pbkdf2.Key. +*/ +package pbkdf2 // import "golang.org/x/crypto/pbkdf2" + +import ( + "crypto/hmac" + "hash" +) + +// Key derives a key from the password, salt and iteration count, returning a +// []byte of length keylen that can be used as cryptographic key. The key is +// derived based on the method described as PBKDF2 with the HMAC variant using +// the supplied hash function. +// +// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you +// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by +// doing: +// +// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) +// +// Remember to get a good random salt. At least 8 bytes is recommended by the +// RFC. +// +// Using a higher iteration count will increase the cost of an exhaustive +// search but will also make derivation proportionally slower. +func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2_test.go b/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2_test.go new file mode 100644 index 000000000..137924061 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/pbkdf2/pbkdf2_test.go @@ -0,0 +1,157 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pbkdf2 + +import ( + "bytes" + "crypto/sha1" + "crypto/sha256" + "hash" + "testing" +) + +type testVector struct { + password string + salt string + iter int + output []byte +} + +// Test vectors from RFC 6070, http://tools.ietf.org/html/rfc6070 +var sha1TestVectors = []testVector{ + { + "password", + "salt", + 1, + []byte{ + 0x0c, 0x60, 0xc8, 0x0f, 0x96, 0x1f, 0x0e, 0x71, + 0xf3, 0xa9, 0xb5, 0x24, 0xaf, 0x60, 0x12, 0x06, + 0x2f, 0xe0, 0x37, 0xa6, + }, + }, + { + "password", + "salt", + 2, + []byte{ + 0xea, 0x6c, 0x01, 0x4d, 0xc7, 0x2d, 0x6f, 0x8c, + 0xcd, 0x1e, 0xd9, 0x2a, 0xce, 0x1d, 0x41, 0xf0, + 0xd8, 0xde, 0x89, 0x57, + }, + }, + { + "password", + "salt", + 4096, + []byte{ + 0x4b, 0x00, 0x79, 0x01, 0xb7, 0x65, 0x48, 0x9a, + 0xbe, 0xad, 0x49, 0xd9, 0x26, 0xf7, 0x21, 0xd0, + 0x65, 0xa4, 0x29, 0xc1, + }, + }, + // // This one takes too long + // { + // "password", + // "salt", + // 16777216, + // []byte{ + // 0xee, 0xfe, 0x3d, 0x61, 0xcd, 0x4d, 0xa4, 0xe4, + // 0xe9, 0x94, 0x5b, 0x3d, 0x6b, 0xa2, 0x15, 0x8c, + // 0x26, 0x34, 0xe9, 0x84, + // }, + // }, + { + "passwordPASSWORDpassword", + "saltSALTsaltSALTsaltSALTsaltSALTsalt", + 4096, + []byte{ + 0x3d, 0x2e, 0xec, 0x4f, 0xe4, 0x1c, 0x84, 0x9b, + 0x80, 0xc8, 0xd8, 0x36, 0x62, 0xc0, 0xe4, 0x4a, + 0x8b, 0x29, 0x1a, 0x96, 0x4c, 0xf2, 0xf0, 0x70, + 0x38, + }, + }, + { + "pass\000word", + "sa\000lt", + 4096, + []byte{ + 0x56, 0xfa, 0x6a, 0xa7, 0x55, 0x48, 0x09, 0x9d, + 0xcc, 0x37, 0xd7, 0xf0, 0x34, 0x25, 0xe0, 0xc3, + }, + }, +} + +// Test vectors from +// http://stackoverflow.com/questions/5130513/pbkdf2-hmac-sha2-test-vectors +var sha256TestVectors = []testVector{ + { + "password", + "salt", + 1, + []byte{ + 0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, + 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8, 0x37, + 0xa8, 0x65, 0x48, 0xc9, + }, + }, + { + "password", + "salt", + 2, + []byte{ + 0xae, 0x4d, 0x0c, 0x95, 0xaf, 0x6b, 0x46, 0xd3, + 0x2d, 0x0a, 0xdf, 0xf9, 0x28, 0xf0, 0x6d, 0xd0, + 0x2a, 0x30, 0x3f, 0x8e, + }, + }, + { + "password", + "salt", + 4096, + []byte{ + 0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, + 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c, 0x4c, 0x8d, + 0x96, 0x28, 0x93, 0xa0, + }, + }, + { + "passwordPASSWORDpassword", + "saltSALTsaltSALTsaltSALTsaltSALTsalt", + 4096, + []byte{ + 0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f, + 0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e, 0x84, 0xcf, + 0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18, + 0x1c, + }, + }, + { + "pass\000word", + "sa\000lt", + 4096, + []byte{ + 0x89, 0xb6, 0x9d, 0x05, 0x16, 0xf8, 0x29, 0x89, + 0x3c, 0x69, 0x62, 0x26, 0x65, 0x0a, 0x86, 0x87, + }, + }, +} + +func testHash(t *testing.T, h func() hash.Hash, hashName string, vectors []testVector) { + for i, v := range vectors { + o := Key([]byte(v.password), []byte(v.salt), v.iter, len(v.output), h) + if !bytes.Equal(o, v.output) { + t.Errorf("%s %d: expected %x, got %x", hashName, i, v.output, o) + } + } +} + +func TestWithHMACSHA1(t *testing.T) { + testHash(t, sha1.New, "SHA1", sha1TestVectors) +} + +func TestWithHMACSHA256(t *testing.T) { + testHash(t, sha256.New, "SHA256", sha256TestVectors) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/const_amd64.s b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/const_amd64.s new file mode 100644 index 000000000..8e861f337 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/const_amd64.s @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +DATA ·SCALE(SB)/8, $0x37F4000000000000 +GLOBL ·SCALE(SB), 8, $8 +DATA ·TWO32(SB)/8, $0x41F0000000000000 +GLOBL ·TWO32(SB), 8, $8 +DATA ·TWO64(SB)/8, $0x43F0000000000000 +GLOBL ·TWO64(SB), 8, $8 +DATA ·TWO96(SB)/8, $0x45F0000000000000 +GLOBL ·TWO96(SB), 8, $8 +DATA ·ALPHA32(SB)/8, $0x45E8000000000000 +GLOBL ·ALPHA32(SB), 8, $8 +DATA ·ALPHA64(SB)/8, $0x47E8000000000000 +GLOBL ·ALPHA64(SB), 8, $8 +DATA ·ALPHA96(SB)/8, $0x49E8000000000000 +GLOBL ·ALPHA96(SB), 8, $8 +DATA ·ALPHA130(SB)/8, $0x4C08000000000000 +GLOBL ·ALPHA130(SB), 8, $8 +DATA ·DOFFSET0(SB)/8, $0x4330000000000000 +GLOBL ·DOFFSET0(SB), 8, $8 +DATA ·DOFFSET1(SB)/8, $0x4530000000000000 +GLOBL ·DOFFSET1(SB), 8, $8 +DATA ·DOFFSET2(SB)/8, $0x4730000000000000 +GLOBL ·DOFFSET2(SB), 8, $8 +DATA ·DOFFSET3(SB)/8, $0x4930000000000000 +GLOBL ·DOFFSET3(SB), 8, $8 +DATA ·DOFFSET3MINUSTWO128(SB)/8, $0x492FFFFE00000000 +GLOBL ·DOFFSET3MINUSTWO128(SB), 8, $8 +DATA ·HOFFSET0(SB)/8, $0x43300001FFFFFFFB +GLOBL ·HOFFSET0(SB), 8, $8 +DATA ·HOFFSET1(SB)/8, $0x45300001FFFFFFFE +GLOBL ·HOFFSET1(SB), 8, $8 +DATA ·HOFFSET2(SB)/8, $0x47300001FFFFFFFE +GLOBL ·HOFFSET2(SB), 8, $8 +DATA ·HOFFSET3(SB)/8, $0x49300003FFFFFFFE +GLOBL ·HOFFSET3(SB), 8, $8 +DATA ·ROUNDING(SB)/2, $0x137f +GLOBL ·ROUNDING(SB), 8, $2 diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305.go b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305.go new file mode 100644 index 000000000..4a5f826f7 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305.go @@ -0,0 +1,32 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package poly1305 implements Poly1305 one-time message authentication code as specified in http://cr.yp.to/mac/poly1305-20050329.pdf. + +Poly1305 is a fast, one-time authentication function. It is infeasible for an +attacker to generate an authenticator for a message without the key. However, a +key must only be used for a single message. Authenticating two different +messages with the same key allows an attacker to forge authenticators for other +messages with the same key. + +Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was +used with a fixed key in order to generate one-time keys from an nonce. +However, in this package AES isn't used and the one-time key is specified +directly. +*/ +package poly1305 // import "golang.org/x/crypto/poly1305" + +import "crypto/subtle" + +// TagSize is the size, in bytes, of a poly1305 authenticator. +const TagSize = 16 + +// Verify returns true if mac is a valid authenticator for m with the given +// key. +func Verify(mac *[16]byte, m []byte, key *[32]byte) bool { + var tmp [16]byte + Sum(&tmp, m, key) + return subtle.ConstantTimeCompare(tmp[:], mac[:]) == 1 +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_amd64.s b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_amd64.s new file mode 100644 index 000000000..f8d4ee928 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_amd64.s @@ -0,0 +1,497 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine + +// func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]key) +TEXT ·poly1305(SB),0,$224-32 + MOVQ out+0(FP),DI + MOVQ m+8(FP),SI + MOVQ mlen+16(FP),DX + MOVQ key+24(FP),CX + + MOVQ SP,R11 + MOVQ $31,R9 + NOTQ R9 + ANDQ R9,SP + ADDQ $32,SP + + MOVQ R11,32(SP) + MOVQ R12,40(SP) + MOVQ R13,48(SP) + MOVQ R14,56(SP) + MOVQ R15,64(SP) + MOVQ BX,72(SP) + MOVQ BP,80(SP) + FLDCW ·ROUNDING(SB) + MOVL 0(CX),R8 + MOVL 4(CX),R9 + MOVL 8(CX),AX + MOVL 12(CX),R10 + MOVQ DI,88(SP) + MOVQ CX,96(SP) + MOVL $0X43300000,108(SP) + MOVL $0X45300000,116(SP) + MOVL $0X47300000,124(SP) + MOVL $0X49300000,132(SP) + ANDL $0X0FFFFFFF,R8 + ANDL $0X0FFFFFFC,R9 + ANDL $0X0FFFFFFC,AX + ANDL $0X0FFFFFFC,R10 + MOVL R8,104(SP) + MOVL R9,112(SP) + MOVL AX,120(SP) + MOVL R10,128(SP) + FMOVD 104(SP), F0 + FSUBD ·DOFFSET0(SB), F0 + FMOVD 112(SP), F0 + FSUBD ·DOFFSET1(SB), F0 + FMOVD 120(SP), F0 + FSUBD ·DOFFSET2(SB), F0 + FMOVD 128(SP), F0 + FSUBD ·DOFFSET3(SB), F0 + FXCHD F0, F3 + FMOVDP F0, 136(SP) + FXCHD F0, F1 + FMOVD F0, 144(SP) + FMULD ·SCALE(SB), F0 + FMOVDP F0, 152(SP) + FMOVD F0, 160(SP) + FMULD ·SCALE(SB), F0 + FMOVDP F0, 168(SP) + FMOVD F0, 176(SP) + FMULD ·SCALE(SB), F0 + FMOVDP F0, 184(SP) + FLDZ + FLDZ + FLDZ + FLDZ + CMPQ DX,$16 + JB ADDATMOST15BYTES + INITIALATLEAST16BYTES: + MOVL 12(SI),DI + MOVL 8(SI),CX + MOVL 4(SI),R8 + MOVL 0(SI),R9 + MOVL DI,128(SP) + MOVL CX,120(SP) + MOVL R8,112(SP) + MOVL R9,104(SP) + ADDQ $16,SI + SUBQ $16,DX + FXCHD F0, F3 + FADDD 128(SP), F0 + FSUBD ·DOFFSET3MINUSTWO128(SB), F0 + FXCHD F0, F1 + FADDD 112(SP), F0 + FSUBD ·DOFFSET1(SB), F0 + FXCHD F0, F2 + FADDD 120(SP), F0 + FSUBD ·DOFFSET2(SB), F0 + FXCHD F0, F3 + FADDD 104(SP), F0 + FSUBD ·DOFFSET0(SB), F0 + CMPQ DX,$16 + JB MULTIPLYADDATMOST15BYTES + MULTIPLYADDATLEAST16BYTES: + MOVL 12(SI),DI + MOVL 8(SI),CX + MOVL 4(SI),R8 + MOVL 0(SI),R9 + MOVL DI,128(SP) + MOVL CX,120(SP) + MOVL R8,112(SP) + MOVL R9,104(SP) + ADDQ $16,SI + SUBQ $16,DX + FMOVD ·ALPHA130(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA130(SB), F0 + FSUBD F0,F2 + FMULD ·SCALE(SB), F0 + FMOVD ·ALPHA32(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA32(SB), F0 + FSUBD F0,F2 + FXCHD F0, F2 + FADDDP F0,F1 + FMOVD ·ALPHA64(SB), F0 + FADDD F4,F0 + FSUBD ·ALPHA64(SB), F0 + FSUBD F0,F4 + FMOVD ·ALPHA96(SB), F0 + FADDD F6,F0 + FSUBD ·ALPHA96(SB), F0 + FSUBD F0,F6 + FXCHD F0, F6 + FADDDP F0,F1 + FXCHD F0, F3 + FADDDP F0,F5 + FXCHD F0, F3 + FADDDP F0,F1 + FMOVD 176(SP), F0 + FMULD F3,F0 + FMOVD 160(SP), F0 + FMULD F4,F0 + FMOVD 144(SP), F0 + FMULD F5,F0 + FMOVD 136(SP), F0 + FMULDP F0,F6 + FMOVD 160(SP), F0 + FMULD F4,F0 + FADDDP F0,F3 + FMOVD 144(SP), F0 + FMULD F4,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F4,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULDP F0,F4 + FXCHD F0, F3 + FADDDP F0,F5 + FMOVD 144(SP), F0 + FMULD F4,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F4,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULD F4,F0 + FADDDP F0,F3 + FMOVD 168(SP), F0 + FMULDP F0,F4 + FXCHD F0, F3 + FADDDP F0,F4 + FMOVD 136(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FXCHD F0, F3 + FMOVD 184(SP), F0 + FMULD F5,F0 + FADDDP F0,F3 + FXCHD F0, F1 + FMOVD 168(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FMOVD 152(SP), F0 + FMULDP F0,F5 + FXCHD F0, F4 + FADDDP F0,F1 + CMPQ DX,$16 + FXCHD F0, F2 + FMOVD 128(SP), F0 + FSUBD ·DOFFSET3MINUSTWO128(SB), F0 + FADDDP F0,F1 + FXCHD F0, F1 + FMOVD 120(SP), F0 + FSUBD ·DOFFSET2(SB), F0 + FADDDP F0,F1 + FXCHD F0, F3 + FMOVD 112(SP), F0 + FSUBD ·DOFFSET1(SB), F0 + FADDDP F0,F1 + FXCHD F0, F2 + FMOVD 104(SP), F0 + FSUBD ·DOFFSET0(SB), F0 + FADDDP F0,F1 + JAE MULTIPLYADDATLEAST16BYTES + MULTIPLYADDATMOST15BYTES: + FMOVD ·ALPHA130(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA130(SB), F0 + FSUBD F0,F2 + FMULD ·SCALE(SB), F0 + FMOVD ·ALPHA32(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA32(SB), F0 + FSUBD F0,F2 + FMOVD ·ALPHA64(SB), F0 + FADDD F5,F0 + FSUBD ·ALPHA64(SB), F0 + FSUBD F0,F5 + FMOVD ·ALPHA96(SB), F0 + FADDD F7,F0 + FSUBD ·ALPHA96(SB), F0 + FSUBD F0,F7 + FXCHD F0, F7 + FADDDP F0,F1 + FXCHD F0, F5 + FADDDP F0,F1 + FXCHD F0, F3 + FADDDP F0,F5 + FADDDP F0,F1 + FMOVD 176(SP), F0 + FMULD F1,F0 + FMOVD 160(SP), F0 + FMULD F2,F0 + FMOVD 144(SP), F0 + FMULD F3,F0 + FMOVD 136(SP), F0 + FMULDP F0,F4 + FMOVD 160(SP), F0 + FMULD F5,F0 + FADDDP F0,F3 + FMOVD 144(SP), F0 + FMULD F5,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULDP F0,F5 + FXCHD F0, F4 + FADDDP F0,F3 + FMOVD 144(SP), F0 + FMULD F5,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULD F5,F0 + FADDDP F0,F4 + FMOVD 168(SP), F0 + FMULDP F0,F5 + FXCHD F0, F4 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULD F5,F0 + FADDDP F0,F4 + FMOVD 168(SP), F0 + FMULD F5,F0 + FADDDP F0,F3 + FMOVD 152(SP), F0 + FMULDP F0,F5 + FXCHD F0, F4 + FADDDP F0,F1 + ADDATMOST15BYTES: + CMPQ DX,$0 + JE NOMOREBYTES + MOVL $0,0(SP) + MOVL $0, 4 (SP) + MOVL $0, 8 (SP) + MOVL $0, 12 (SP) + LEAQ 0(SP),DI + MOVQ DX,CX + REP; MOVSB + MOVB $1,0(DI) + MOVL 12 (SP),DI + MOVL 8 (SP),SI + MOVL 4 (SP),DX + MOVL 0(SP),CX + MOVL DI,128(SP) + MOVL SI,120(SP) + MOVL DX,112(SP) + MOVL CX,104(SP) + FXCHD F0, F3 + FADDD 128(SP), F0 + FSUBD ·DOFFSET3(SB), F0 + FXCHD F0, F2 + FADDD 120(SP), F0 + FSUBD ·DOFFSET2(SB), F0 + FXCHD F0, F1 + FADDD 112(SP), F0 + FSUBD ·DOFFSET1(SB), F0 + FXCHD F0, F3 + FADDD 104(SP), F0 + FSUBD ·DOFFSET0(SB), F0 + FMOVD ·ALPHA130(SB), F0 + FADDD F3,F0 + FSUBD ·ALPHA130(SB), F0 + FSUBD F0,F3 + FMULD ·SCALE(SB), F0 + FMOVD ·ALPHA32(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA32(SB), F0 + FSUBD F0,F2 + FMOVD ·ALPHA64(SB), F0 + FADDD F6,F0 + FSUBD ·ALPHA64(SB), F0 + FSUBD F0,F6 + FMOVD ·ALPHA96(SB), F0 + FADDD F5,F0 + FSUBD ·ALPHA96(SB), F0 + FSUBD F0,F5 + FXCHD F0, F4 + FADDDP F0,F3 + FXCHD F0, F6 + FADDDP F0,F1 + FXCHD F0, F3 + FADDDP F0,F5 + FXCHD F0, F3 + FADDDP F0,F1 + FMOVD 176(SP), F0 + FMULD F3,F0 + FMOVD 160(SP), F0 + FMULD F4,F0 + FMOVD 144(SP), F0 + FMULD F5,F0 + FMOVD 136(SP), F0 + FMULDP F0,F6 + FMOVD 160(SP), F0 + FMULD F5,F0 + FADDDP F0,F3 + FMOVD 144(SP), F0 + FMULD F5,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F5,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULDP F0,F5 + FXCHD F0, F4 + FADDDP F0,F5 + FMOVD 144(SP), F0 + FMULD F6,F0 + FADDDP F0,F2 + FMOVD 136(SP), F0 + FMULD F6,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULD F6,F0 + FADDDP F0,F4 + FMOVD 168(SP), F0 + FMULDP F0,F6 + FXCHD F0, F5 + FADDDP F0,F4 + FMOVD 136(SP), F0 + FMULD F2,F0 + FADDDP F0,F1 + FMOVD 184(SP), F0 + FMULD F2,F0 + FADDDP F0,F5 + FMOVD 168(SP), F0 + FMULD F2,F0 + FADDDP F0,F3 + FMOVD 152(SP), F0 + FMULDP F0,F2 + FXCHD F0, F1 + FADDDP F0,F3 + FXCHD F0, F3 + FXCHD F0, F2 + NOMOREBYTES: + MOVL $0,R10 + FMOVD ·ALPHA130(SB), F0 + FADDD F4,F0 + FSUBD ·ALPHA130(SB), F0 + FSUBD F0,F4 + FMULD ·SCALE(SB), F0 + FMOVD ·ALPHA32(SB), F0 + FADDD F2,F0 + FSUBD ·ALPHA32(SB), F0 + FSUBD F0,F2 + FMOVD ·ALPHA64(SB), F0 + FADDD F4,F0 + FSUBD ·ALPHA64(SB), F0 + FSUBD F0,F4 + FMOVD ·ALPHA96(SB), F0 + FADDD F6,F0 + FSUBD ·ALPHA96(SB), F0 + FXCHD F0, F6 + FSUBD F6,F0 + FXCHD F0, F4 + FADDDP F0,F3 + FXCHD F0, F4 + FADDDP F0,F1 + FXCHD F0, F2 + FADDDP F0,F3 + FXCHD F0, F4 + FADDDP F0,F3 + FXCHD F0, F3 + FADDD ·HOFFSET0(SB), F0 + FXCHD F0, F3 + FADDD ·HOFFSET1(SB), F0 + FXCHD F0, F1 + FADDD ·HOFFSET2(SB), F0 + FXCHD F0, F2 + FADDD ·HOFFSET3(SB), F0 + FXCHD F0, F3 + FMOVDP F0, 104(SP) + FMOVDP F0, 112(SP) + FMOVDP F0, 120(SP) + FMOVDP F0, 128(SP) + MOVL 108(SP),DI + ANDL $63,DI + MOVL 116(SP),SI + ANDL $63,SI + MOVL 124(SP),DX + ANDL $63,DX + MOVL 132(SP),CX + ANDL $63,CX + MOVL 112(SP),R8 + ADDL DI,R8 + MOVQ R8,112(SP) + MOVL 120(SP),DI + ADCL SI,DI + MOVQ DI,120(SP) + MOVL 128(SP),DI + ADCL DX,DI + MOVQ DI,128(SP) + MOVL R10,DI + ADCL CX,DI + MOVQ DI,136(SP) + MOVQ $5,DI + MOVL 104(SP),SI + ADDL SI,DI + MOVQ DI,104(SP) + MOVL R10,DI + MOVQ 112(SP),DX + ADCL DX,DI + MOVQ DI,112(SP) + MOVL R10,DI + MOVQ 120(SP),CX + ADCL CX,DI + MOVQ DI,120(SP) + MOVL R10,DI + MOVQ 128(SP),R8 + ADCL R8,DI + MOVQ DI,128(SP) + MOVQ $0XFFFFFFFC,DI + MOVQ 136(SP),R9 + ADCL R9,DI + SARL $16,DI + MOVQ DI,R9 + XORL $0XFFFFFFFF,R9 + ANDQ DI,SI + MOVQ 104(SP),AX + ANDQ R9,AX + ORQ AX,SI + ANDQ DI,DX + MOVQ 112(SP),AX + ANDQ R9,AX + ORQ AX,DX + ANDQ DI,CX + MOVQ 120(SP),AX + ANDQ R9,AX + ORQ AX,CX + ANDQ DI,R8 + MOVQ 128(SP),DI + ANDQ R9,DI + ORQ DI,R8 + MOVQ 88(SP),DI + MOVQ 96(SP),R9 + ADDL 16(R9),SI + ADCL 20(R9),DX + ADCL 24(R9),CX + ADCL 28(R9),R8 + MOVL SI,0(DI) + MOVL DX,4(DI) + MOVL CX,8(DI) + MOVL R8,12(DI) + MOVQ 32(SP),R11 + MOVQ 40(SP),R12 + MOVQ 48(SP),R13 + MOVQ 56(SP),R14 + MOVQ 64(SP),R15 + MOVQ 72(SP),BX + MOVQ 80(SP),BP + MOVQ R11,SP + RET diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_test.go b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_test.go new file mode 100644 index 000000000..2c6d1bc98 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/poly1305_test.go @@ -0,0 +1,74 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package poly1305 + +import ( + "bytes" + "testing" +) + +var testData = []struct { + in, k, correct []byte +}{ + { + []byte("Hello world!"), + []byte("this is 32-byte key for Poly1305"), + []byte{0xa6, 0xf7, 0x45, 0x00, 0x8f, 0x81, 0xc9, 0x16, 0xa2, 0x0d, 0xcc, 0x74, 0xee, 0xf2, 0xb2, 0xf0}, + }, + { + make([]byte, 32), + []byte("this is 32-byte key for Poly1305"), + []byte{0x49, 0xec, 0x78, 0x09, 0x0e, 0x48, 0x1e, 0xc6, 0xc2, 0x6b, 0x33, 0xb9, 0x1c, 0xcc, 0x03, 0x07}, + }, + { + make([]byte, 2007), + []byte("this is 32-byte key for Poly1305"), + []byte{0xda, 0x84, 0xbc, 0xab, 0x02, 0x67, 0x6c, 0x38, 0xcd, 0xb0, 0x15, 0x60, 0x42, 0x74, 0xc2, 0xaa}, + }, + { + make([]byte, 2007), + make([]byte, 32), + make([]byte, 16), + }, +} + +func TestSum(t *testing.T) { + var out [16]byte + var key [32]byte + + for i, v := range testData { + copy(key[:], v.k) + Sum(&out, v.in, &key) + if !bytes.Equal(out[:], v.correct) { + t.Errorf("%d: expected %x, got %x", i, v.correct, out[:]) + } + } +} + +func Benchmark1K(b *testing.B) { + b.StopTimer() + var out [16]byte + var key [32]byte + in := make([]byte, 1024) + b.SetBytes(int64(len(in))) + b.StartTimer() + + for i := 0; i < b.N; i++ { + Sum(&out, in, &key) + } +} + +func Benchmark64(b *testing.B) { + b.StopTimer() + var out [16]byte + var key [32]byte + in := make([]byte, 64) + b.SetBytes(int64(len(in))) + b.StartTimer() + + for i := 0; i < b.N; i++ { + Sum(&out, in, &key) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_amd64.go b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_amd64.go new file mode 100644 index 000000000..6775c703f --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_amd64.go @@ -0,0 +1,24 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine + +package poly1305 + +// This function is implemented in poly1305_amd64.s + +//go:noescape + +func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]byte) + +// Sum generates an authenticator for m using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + var mPtr *byte + if len(m) > 0 { + mPtr = &m[0] + } + poly1305(out, mPtr, uint64(len(m)), key) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_ref.go b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_ref.go new file mode 100644 index 000000000..92d38018d --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/poly1305/sum_ref.go @@ -0,0 +1,1531 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64 gccgo appengine + +package poly1305 + +// Based on original, public domain implementation from NaCl by D. J. +// Bernstein. + +import "math" + +const ( + alpham80 = 0.00000000558793544769287109375 + alpham48 = 24.0 + alpham16 = 103079215104.0 + alpha0 = 6755399441055744.0 + alpha18 = 1770887431076116955136.0 + alpha32 = 29014219670751100192948224.0 + alpha50 = 7605903601369376408980219232256.0 + alpha64 = 124615124604835863084731911901282304.0 + alpha82 = 32667107224410092492483962313449748299776.0 + alpha96 = 535217884764734955396857238543560676143529984.0 + alpha112 = 35076039295941670036888435985190792471742381031424.0 + alpha130 = 9194973245195333150150082162901855101712434733101613056.0 + scale = 0.0000000000000000000000000000000000000036734198463196484624023016788195177431833298649127735047148490821200539357960224151611328125 + offset0 = 6755408030990331.0 + offset1 = 29014256564239239022116864.0 + offset2 = 124615283061160854719918951570079744.0 + offset3 = 535219245894202480694386063513315216128475136.0 +) + +// Sum generates an authenticator for m using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + r := key + s := key[16:] + var ( + y7 float64 + y6 float64 + y1 float64 + y0 float64 + y5 float64 + y4 float64 + x7 float64 + x6 float64 + x1 float64 + x0 float64 + y3 float64 + y2 float64 + x5 float64 + r3lowx0 float64 + x4 float64 + r0lowx6 float64 + x3 float64 + r3highx0 float64 + x2 float64 + r0highx6 float64 + r0lowx0 float64 + sr1lowx6 float64 + r0highx0 float64 + sr1highx6 float64 + sr3low float64 + r1lowx0 float64 + sr2lowx6 float64 + r1highx0 float64 + sr2highx6 float64 + r2lowx0 float64 + sr3lowx6 float64 + r2highx0 float64 + sr3highx6 float64 + r1highx4 float64 + r1lowx4 float64 + r0highx4 float64 + r0lowx4 float64 + sr3highx4 float64 + sr3lowx4 float64 + sr2highx4 float64 + sr2lowx4 float64 + r0lowx2 float64 + r0highx2 float64 + r1lowx2 float64 + r1highx2 float64 + r2lowx2 float64 + r2highx2 float64 + sr3lowx2 float64 + sr3highx2 float64 + z0 float64 + z1 float64 + z2 float64 + z3 float64 + m0 int64 + m1 int64 + m2 int64 + m3 int64 + m00 uint32 + m01 uint32 + m02 uint32 + m03 uint32 + m10 uint32 + m11 uint32 + m12 uint32 + m13 uint32 + m20 uint32 + m21 uint32 + m22 uint32 + m23 uint32 + m30 uint32 + m31 uint32 + m32 uint32 + m33 uint64 + lbelow2 int32 + lbelow3 int32 + lbelow4 int32 + lbelow5 int32 + lbelow6 int32 + lbelow7 int32 + lbelow8 int32 + lbelow9 int32 + lbelow10 int32 + lbelow11 int32 + lbelow12 int32 + lbelow13 int32 + lbelow14 int32 + lbelow15 int32 + s00 uint32 + s01 uint32 + s02 uint32 + s03 uint32 + s10 uint32 + s11 uint32 + s12 uint32 + s13 uint32 + s20 uint32 + s21 uint32 + s22 uint32 + s23 uint32 + s30 uint32 + s31 uint32 + s32 uint32 + s33 uint32 + bits32 uint64 + f uint64 + f0 uint64 + f1 uint64 + f2 uint64 + f3 uint64 + f4 uint64 + g uint64 + g0 uint64 + g1 uint64 + g2 uint64 + g3 uint64 + g4 uint64 + ) + + var p int32 + + l := int32(len(m)) + + r00 := uint32(r[0]) + + r01 := uint32(r[1]) + + r02 := uint32(r[2]) + r0 := int64(2151) + + r03 := uint32(r[3]) + r03 &= 15 + r0 <<= 51 + + r10 := uint32(r[4]) + r10 &= 252 + r01 <<= 8 + r0 += int64(r00) + + r11 := uint32(r[5]) + r02 <<= 16 + r0 += int64(r01) + + r12 := uint32(r[6]) + r03 <<= 24 + r0 += int64(r02) + + r13 := uint32(r[7]) + r13 &= 15 + r1 := int64(2215) + r0 += int64(r03) + + d0 := r0 + r1 <<= 51 + r2 := int64(2279) + + r20 := uint32(r[8]) + r20 &= 252 + r11 <<= 8 + r1 += int64(r10) + + r21 := uint32(r[9]) + r12 <<= 16 + r1 += int64(r11) + + r22 := uint32(r[10]) + r13 <<= 24 + r1 += int64(r12) + + r23 := uint32(r[11]) + r23 &= 15 + r2 <<= 51 + r1 += int64(r13) + + d1 := r1 + r21 <<= 8 + r2 += int64(r20) + + r30 := uint32(r[12]) + r30 &= 252 + r22 <<= 16 + r2 += int64(r21) + + r31 := uint32(r[13]) + r23 <<= 24 + r2 += int64(r22) + + r32 := uint32(r[14]) + r2 += int64(r23) + r3 := int64(2343) + + d2 := r2 + r3 <<= 51 + + r33 := uint32(r[15]) + r33 &= 15 + r31 <<= 8 + r3 += int64(r30) + + r32 <<= 16 + r3 += int64(r31) + + r33 <<= 24 + r3 += int64(r32) + + r3 += int64(r33) + h0 := alpha32 - alpha32 + + d3 := r3 + h1 := alpha32 - alpha32 + + h2 := alpha32 - alpha32 + + h3 := alpha32 - alpha32 + + h4 := alpha32 - alpha32 + + r0low := math.Float64frombits(uint64(d0)) + h5 := alpha32 - alpha32 + + r1low := math.Float64frombits(uint64(d1)) + h6 := alpha32 - alpha32 + + r2low := math.Float64frombits(uint64(d2)) + h7 := alpha32 - alpha32 + + r0low -= alpha0 + + r1low -= alpha32 + + r2low -= alpha64 + + r0high := r0low + alpha18 + + r3low := math.Float64frombits(uint64(d3)) + + r1high := r1low + alpha50 + sr1low := scale * r1low + + r2high := r2low + alpha82 + sr2low := scale * r2low + + r0high -= alpha18 + r0high_stack := r0high + + r3low -= alpha96 + + r1high -= alpha50 + r1high_stack := r1high + + sr1high := sr1low + alpham80 + + r0low -= r0high + + r2high -= alpha82 + sr3low = scale * r3low + + sr2high := sr2low + alpham48 + + r1low -= r1high + r1low_stack := r1low + + sr1high -= alpham80 + sr1high_stack := sr1high + + r2low -= r2high + r2low_stack := r2low + + sr2high -= alpham48 + sr2high_stack := sr2high + + r3high := r3low + alpha112 + r0low_stack := r0low + + sr1low -= sr1high + sr1low_stack := sr1low + + sr3high := sr3low + alpham16 + r2high_stack := r2high + + sr2low -= sr2high + sr2low_stack := sr2low + + r3high -= alpha112 + r3high_stack := r3high + + sr3high -= alpham16 + sr3high_stack := sr3high + + r3low -= r3high + r3low_stack := r3low + + sr3low -= sr3high + sr3low_stack := sr3low + + if l < 16 { + goto addatmost15bytes + } + + m00 = uint32(m[p+0]) + m0 = 2151 + + m0 <<= 51 + m1 = 2215 + m01 = uint32(m[p+1]) + + m1 <<= 51 + m2 = 2279 + m02 = uint32(m[p+2]) + + m2 <<= 51 + m3 = 2343 + m03 = uint32(m[p+3]) + + m10 = uint32(m[p+4]) + m01 <<= 8 + m0 += int64(m00) + + m11 = uint32(m[p+5]) + m02 <<= 16 + m0 += int64(m01) + + m12 = uint32(m[p+6]) + m03 <<= 24 + m0 += int64(m02) + + m13 = uint32(m[p+7]) + m3 <<= 51 + m0 += int64(m03) + + m20 = uint32(m[p+8]) + m11 <<= 8 + m1 += int64(m10) + + m21 = uint32(m[p+9]) + m12 <<= 16 + m1 += int64(m11) + + m22 = uint32(m[p+10]) + m13 <<= 24 + m1 += int64(m12) + + m23 = uint32(m[p+11]) + m1 += int64(m13) + + m30 = uint32(m[p+12]) + m21 <<= 8 + m2 += int64(m20) + + m31 = uint32(m[p+13]) + m22 <<= 16 + m2 += int64(m21) + + m32 = uint32(m[p+14]) + m23 <<= 24 + m2 += int64(m22) + + m33 = uint64(m[p+15]) + m2 += int64(m23) + + d0 = m0 + m31 <<= 8 + m3 += int64(m30) + + d1 = m1 + m32 <<= 16 + m3 += int64(m31) + + d2 = m2 + m33 += 256 + + m33 <<= 24 + m3 += int64(m32) + + m3 += int64(m33) + d3 = m3 + + p += 16 + l -= 16 + + z0 = math.Float64frombits(uint64(d0)) + + z1 = math.Float64frombits(uint64(d1)) + + z2 = math.Float64frombits(uint64(d2)) + + z3 = math.Float64frombits(uint64(d3)) + + z0 -= alpha0 + + z1 -= alpha32 + + z2 -= alpha64 + + z3 -= alpha96 + + h0 += z0 + + h1 += z1 + + h3 += z2 + + h5 += z3 + + if l < 16 { + goto multiplyaddatmost15bytes + } + +multiplyaddatleast16bytes: + + m2 = 2279 + m20 = uint32(m[p+8]) + y7 = h7 + alpha130 + + m2 <<= 51 + m3 = 2343 + m21 = uint32(m[p+9]) + y6 = h6 + alpha130 + + m3 <<= 51 + m0 = 2151 + m22 = uint32(m[p+10]) + y1 = h1 + alpha32 + + m0 <<= 51 + m1 = 2215 + m23 = uint32(m[p+11]) + y0 = h0 + alpha32 + + m1 <<= 51 + m30 = uint32(m[p+12]) + y7 -= alpha130 + + m21 <<= 8 + m2 += int64(m20) + m31 = uint32(m[p+13]) + y6 -= alpha130 + + m22 <<= 16 + m2 += int64(m21) + m32 = uint32(m[p+14]) + y1 -= alpha32 + + m23 <<= 24 + m2 += int64(m22) + m33 = uint64(m[p+15]) + y0 -= alpha32 + + m2 += int64(m23) + m00 = uint32(m[p+0]) + y5 = h5 + alpha96 + + m31 <<= 8 + m3 += int64(m30) + m01 = uint32(m[p+1]) + y4 = h4 + alpha96 + + m32 <<= 16 + m02 = uint32(m[p+2]) + x7 = h7 - y7 + y7 *= scale + + m33 += 256 + m03 = uint32(m[p+3]) + x6 = h6 - y6 + y6 *= scale + + m33 <<= 24 + m3 += int64(m31) + m10 = uint32(m[p+4]) + x1 = h1 - y1 + + m01 <<= 8 + m3 += int64(m32) + m11 = uint32(m[p+5]) + x0 = h0 - y0 + + m3 += int64(m33) + m0 += int64(m00) + m12 = uint32(m[p+6]) + y5 -= alpha96 + + m02 <<= 16 + m0 += int64(m01) + m13 = uint32(m[p+7]) + y4 -= alpha96 + + m03 <<= 24 + m0 += int64(m02) + d2 = m2 + x1 += y7 + + m0 += int64(m03) + d3 = m3 + x0 += y6 + + m11 <<= 8 + m1 += int64(m10) + d0 = m0 + x7 += y5 + + m12 <<= 16 + m1 += int64(m11) + x6 += y4 + + m13 <<= 24 + m1 += int64(m12) + y3 = h3 + alpha64 + + m1 += int64(m13) + d1 = m1 + y2 = h2 + alpha64 + + x0 += x1 + + x6 += x7 + + y3 -= alpha64 + r3low = r3low_stack + + y2 -= alpha64 + r0low = r0low_stack + + x5 = h5 - y5 + r3lowx0 = r3low * x0 + r3high = r3high_stack + + x4 = h4 - y4 + r0lowx6 = r0low * x6 + r0high = r0high_stack + + x3 = h3 - y3 + r3highx0 = r3high * x0 + sr1low = sr1low_stack + + x2 = h2 - y2 + r0highx6 = r0high * x6 + sr1high = sr1high_stack + + x5 += y3 + r0lowx0 = r0low * x0 + r1low = r1low_stack + + h6 = r3lowx0 + r0lowx6 + sr1lowx6 = sr1low * x6 + r1high = r1high_stack + + x4 += y2 + r0highx0 = r0high * x0 + sr2low = sr2low_stack + + h7 = r3highx0 + r0highx6 + sr1highx6 = sr1high * x6 + sr2high = sr2high_stack + + x3 += y1 + r1lowx0 = r1low * x0 + r2low = r2low_stack + + h0 = r0lowx0 + sr1lowx6 + sr2lowx6 = sr2low * x6 + r2high = r2high_stack + + x2 += y0 + r1highx0 = r1high * x0 + sr3low = sr3low_stack + + h1 = r0highx0 + sr1highx6 + sr2highx6 = sr2high * x6 + sr3high = sr3high_stack + + x4 += x5 + r2lowx0 = r2low * x0 + z2 = math.Float64frombits(uint64(d2)) + + h2 = r1lowx0 + sr2lowx6 + sr3lowx6 = sr3low * x6 + + x2 += x3 + r2highx0 = r2high * x0 + z3 = math.Float64frombits(uint64(d3)) + + h3 = r1highx0 + sr2highx6 + sr3highx6 = sr3high * x6 + + r1highx4 = r1high * x4 + z2 -= alpha64 + + h4 = r2lowx0 + sr3lowx6 + r1lowx4 = r1low * x4 + + r0highx4 = r0high * x4 + z3 -= alpha96 + + h5 = r2highx0 + sr3highx6 + r0lowx4 = r0low * x4 + + h7 += r1highx4 + sr3highx4 = sr3high * x4 + + h6 += r1lowx4 + sr3lowx4 = sr3low * x4 + + h5 += r0highx4 + sr2highx4 = sr2high * x4 + + h4 += r0lowx4 + sr2lowx4 = sr2low * x4 + + h3 += sr3highx4 + r0lowx2 = r0low * x2 + + h2 += sr3lowx4 + r0highx2 = r0high * x2 + + h1 += sr2highx4 + r1lowx2 = r1low * x2 + + h0 += sr2lowx4 + r1highx2 = r1high * x2 + + h2 += r0lowx2 + r2lowx2 = r2low * x2 + + h3 += r0highx2 + r2highx2 = r2high * x2 + + h4 += r1lowx2 + sr3lowx2 = sr3low * x2 + + h5 += r1highx2 + sr3highx2 = sr3high * x2 + + p += 16 + l -= 16 + h6 += r2lowx2 + + h7 += r2highx2 + + z1 = math.Float64frombits(uint64(d1)) + h0 += sr3lowx2 + + z0 = math.Float64frombits(uint64(d0)) + h1 += sr3highx2 + + z1 -= alpha32 + + z0 -= alpha0 + + h5 += z3 + + h3 += z2 + + h1 += z1 + + h0 += z0 + + if l >= 16 { + goto multiplyaddatleast16bytes + } + +multiplyaddatmost15bytes: + + y7 = h7 + alpha130 + + y6 = h6 + alpha130 + + y1 = h1 + alpha32 + + y0 = h0 + alpha32 + + y7 -= alpha130 + + y6 -= alpha130 + + y1 -= alpha32 + + y0 -= alpha32 + + y5 = h5 + alpha96 + + y4 = h4 + alpha96 + + x7 = h7 - y7 + y7 *= scale + + x6 = h6 - y6 + y6 *= scale + + x1 = h1 - y1 + + x0 = h0 - y0 + + y5 -= alpha96 + + y4 -= alpha96 + + x1 += y7 + + x0 += y6 + + x7 += y5 + + x6 += y4 + + y3 = h3 + alpha64 + + y2 = h2 + alpha64 + + x0 += x1 + + x6 += x7 + + y3 -= alpha64 + r3low = r3low_stack + + y2 -= alpha64 + r0low = r0low_stack + + x5 = h5 - y5 + r3lowx0 = r3low * x0 + r3high = r3high_stack + + x4 = h4 - y4 + r0lowx6 = r0low * x6 + r0high = r0high_stack + + x3 = h3 - y3 + r3highx0 = r3high * x0 + sr1low = sr1low_stack + + x2 = h2 - y2 + r0highx6 = r0high * x6 + sr1high = sr1high_stack + + x5 += y3 + r0lowx0 = r0low * x0 + r1low = r1low_stack + + h6 = r3lowx0 + r0lowx6 + sr1lowx6 = sr1low * x6 + r1high = r1high_stack + + x4 += y2 + r0highx0 = r0high * x0 + sr2low = sr2low_stack + + h7 = r3highx0 + r0highx6 + sr1highx6 = sr1high * x6 + sr2high = sr2high_stack + + x3 += y1 + r1lowx0 = r1low * x0 + r2low = r2low_stack + + h0 = r0lowx0 + sr1lowx6 + sr2lowx6 = sr2low * x6 + r2high = r2high_stack + + x2 += y0 + r1highx0 = r1high * x0 + sr3low = sr3low_stack + + h1 = r0highx0 + sr1highx6 + sr2highx6 = sr2high * x6 + sr3high = sr3high_stack + + x4 += x5 + r2lowx0 = r2low * x0 + + h2 = r1lowx0 + sr2lowx6 + sr3lowx6 = sr3low * x6 + + x2 += x3 + r2highx0 = r2high * x0 + + h3 = r1highx0 + sr2highx6 + sr3highx6 = sr3high * x6 + + r1highx4 = r1high * x4 + + h4 = r2lowx0 + sr3lowx6 + r1lowx4 = r1low * x4 + + r0highx4 = r0high * x4 + + h5 = r2highx0 + sr3highx6 + r0lowx4 = r0low * x4 + + h7 += r1highx4 + sr3highx4 = sr3high * x4 + + h6 += r1lowx4 + sr3lowx4 = sr3low * x4 + + h5 += r0highx4 + sr2highx4 = sr2high * x4 + + h4 += r0lowx4 + sr2lowx4 = sr2low * x4 + + h3 += sr3highx4 + r0lowx2 = r0low * x2 + + h2 += sr3lowx4 + r0highx2 = r0high * x2 + + h1 += sr2highx4 + r1lowx2 = r1low * x2 + + h0 += sr2lowx4 + r1highx2 = r1high * x2 + + h2 += r0lowx2 + r2lowx2 = r2low * x2 + + h3 += r0highx2 + r2highx2 = r2high * x2 + + h4 += r1lowx2 + sr3lowx2 = sr3low * x2 + + h5 += r1highx2 + sr3highx2 = sr3high * x2 + + h6 += r2lowx2 + + h7 += r2highx2 + + h0 += sr3lowx2 + + h1 += sr3highx2 + +addatmost15bytes: + + if l == 0 { + goto nomorebytes + } + + lbelow2 = l - 2 + + lbelow3 = l - 3 + + lbelow2 >>= 31 + lbelow4 = l - 4 + + m00 = uint32(m[p+0]) + lbelow3 >>= 31 + p += lbelow2 + + m01 = uint32(m[p+1]) + lbelow4 >>= 31 + p += lbelow3 + + m02 = uint32(m[p+2]) + p += lbelow4 + m0 = 2151 + + m03 = uint32(m[p+3]) + m0 <<= 51 + m1 = 2215 + + m0 += int64(m00) + m01 &^= uint32(lbelow2) + + m02 &^= uint32(lbelow3) + m01 -= uint32(lbelow2) + + m01 <<= 8 + m03 &^= uint32(lbelow4) + + m0 += int64(m01) + lbelow2 -= lbelow3 + + m02 += uint32(lbelow2) + lbelow3 -= lbelow4 + + m02 <<= 16 + m03 += uint32(lbelow3) + + m03 <<= 24 + m0 += int64(m02) + + m0 += int64(m03) + lbelow5 = l - 5 + + lbelow6 = l - 6 + lbelow7 = l - 7 + + lbelow5 >>= 31 + lbelow8 = l - 8 + + lbelow6 >>= 31 + p += lbelow5 + + m10 = uint32(m[p+4]) + lbelow7 >>= 31 + p += lbelow6 + + m11 = uint32(m[p+5]) + lbelow8 >>= 31 + p += lbelow7 + + m12 = uint32(m[p+6]) + m1 <<= 51 + p += lbelow8 + + m13 = uint32(m[p+7]) + m10 &^= uint32(lbelow5) + lbelow4 -= lbelow5 + + m10 += uint32(lbelow4) + lbelow5 -= lbelow6 + + m11 &^= uint32(lbelow6) + m11 += uint32(lbelow5) + + m11 <<= 8 + m1 += int64(m10) + + m1 += int64(m11) + m12 &^= uint32(lbelow7) + + lbelow6 -= lbelow7 + m13 &^= uint32(lbelow8) + + m12 += uint32(lbelow6) + lbelow7 -= lbelow8 + + m12 <<= 16 + m13 += uint32(lbelow7) + + m13 <<= 24 + m1 += int64(m12) + + m1 += int64(m13) + m2 = 2279 + + lbelow9 = l - 9 + m3 = 2343 + + lbelow10 = l - 10 + lbelow11 = l - 11 + + lbelow9 >>= 31 + lbelow12 = l - 12 + + lbelow10 >>= 31 + p += lbelow9 + + m20 = uint32(m[p+8]) + lbelow11 >>= 31 + p += lbelow10 + + m21 = uint32(m[p+9]) + lbelow12 >>= 31 + p += lbelow11 + + m22 = uint32(m[p+10]) + m2 <<= 51 + p += lbelow12 + + m23 = uint32(m[p+11]) + m20 &^= uint32(lbelow9) + lbelow8 -= lbelow9 + + m20 += uint32(lbelow8) + lbelow9 -= lbelow10 + + m21 &^= uint32(lbelow10) + m21 += uint32(lbelow9) + + m21 <<= 8 + m2 += int64(m20) + + m2 += int64(m21) + m22 &^= uint32(lbelow11) + + lbelow10 -= lbelow11 + m23 &^= uint32(lbelow12) + + m22 += uint32(lbelow10) + lbelow11 -= lbelow12 + + m22 <<= 16 + m23 += uint32(lbelow11) + + m23 <<= 24 + m2 += int64(m22) + + m3 <<= 51 + lbelow13 = l - 13 + + lbelow13 >>= 31 + lbelow14 = l - 14 + + lbelow14 >>= 31 + p += lbelow13 + lbelow15 = l - 15 + + m30 = uint32(m[p+12]) + lbelow15 >>= 31 + p += lbelow14 + + m31 = uint32(m[p+13]) + p += lbelow15 + m2 += int64(m23) + + m32 = uint32(m[p+14]) + m30 &^= uint32(lbelow13) + lbelow12 -= lbelow13 + + m30 += uint32(lbelow12) + lbelow13 -= lbelow14 + + m3 += int64(m30) + m31 &^= uint32(lbelow14) + + m31 += uint32(lbelow13) + m32 &^= uint32(lbelow15) + + m31 <<= 8 + lbelow14 -= lbelow15 + + m3 += int64(m31) + m32 += uint32(lbelow14) + d0 = m0 + + m32 <<= 16 + m33 = uint64(lbelow15 + 1) + d1 = m1 + + m33 <<= 24 + m3 += int64(m32) + d2 = m2 + + m3 += int64(m33) + d3 = m3 + + z3 = math.Float64frombits(uint64(d3)) + + z2 = math.Float64frombits(uint64(d2)) + + z1 = math.Float64frombits(uint64(d1)) + + z0 = math.Float64frombits(uint64(d0)) + + z3 -= alpha96 + + z2 -= alpha64 + + z1 -= alpha32 + + z0 -= alpha0 + + h5 += z3 + + h3 += z2 + + h1 += z1 + + h0 += z0 + + y7 = h7 + alpha130 + + y6 = h6 + alpha130 + + y1 = h1 + alpha32 + + y0 = h0 + alpha32 + + y7 -= alpha130 + + y6 -= alpha130 + + y1 -= alpha32 + + y0 -= alpha32 + + y5 = h5 + alpha96 + + y4 = h4 + alpha96 + + x7 = h7 - y7 + y7 *= scale + + x6 = h6 - y6 + y6 *= scale + + x1 = h1 - y1 + + x0 = h0 - y0 + + y5 -= alpha96 + + y4 -= alpha96 + + x1 += y7 + + x0 += y6 + + x7 += y5 + + x6 += y4 + + y3 = h3 + alpha64 + + y2 = h2 + alpha64 + + x0 += x1 + + x6 += x7 + + y3 -= alpha64 + r3low = r3low_stack + + y2 -= alpha64 + r0low = r0low_stack + + x5 = h5 - y5 + r3lowx0 = r3low * x0 + r3high = r3high_stack + + x4 = h4 - y4 + r0lowx6 = r0low * x6 + r0high = r0high_stack + + x3 = h3 - y3 + r3highx0 = r3high * x0 + sr1low = sr1low_stack + + x2 = h2 - y2 + r0highx6 = r0high * x6 + sr1high = sr1high_stack + + x5 += y3 + r0lowx0 = r0low * x0 + r1low = r1low_stack + + h6 = r3lowx0 + r0lowx6 + sr1lowx6 = sr1low * x6 + r1high = r1high_stack + + x4 += y2 + r0highx0 = r0high * x0 + sr2low = sr2low_stack + + h7 = r3highx0 + r0highx6 + sr1highx6 = sr1high * x6 + sr2high = sr2high_stack + + x3 += y1 + r1lowx0 = r1low * x0 + r2low = r2low_stack + + h0 = r0lowx0 + sr1lowx6 + sr2lowx6 = sr2low * x6 + r2high = r2high_stack + + x2 += y0 + r1highx0 = r1high * x0 + sr3low = sr3low_stack + + h1 = r0highx0 + sr1highx6 + sr2highx6 = sr2high * x6 + sr3high = sr3high_stack + + x4 += x5 + r2lowx0 = r2low * x0 + + h2 = r1lowx0 + sr2lowx6 + sr3lowx6 = sr3low * x6 + + x2 += x3 + r2highx0 = r2high * x0 + + h3 = r1highx0 + sr2highx6 + sr3highx6 = sr3high * x6 + + r1highx4 = r1high * x4 + + h4 = r2lowx0 + sr3lowx6 + r1lowx4 = r1low * x4 + + r0highx4 = r0high * x4 + + h5 = r2highx0 + sr3highx6 + r0lowx4 = r0low * x4 + + h7 += r1highx4 + sr3highx4 = sr3high * x4 + + h6 += r1lowx4 + sr3lowx4 = sr3low * x4 + + h5 += r0highx4 + sr2highx4 = sr2high * x4 + + h4 += r0lowx4 + sr2lowx4 = sr2low * x4 + + h3 += sr3highx4 + r0lowx2 = r0low * x2 + + h2 += sr3lowx4 + r0highx2 = r0high * x2 + + h1 += sr2highx4 + r1lowx2 = r1low * x2 + + h0 += sr2lowx4 + r1highx2 = r1high * x2 + + h2 += r0lowx2 + r2lowx2 = r2low * x2 + + h3 += r0highx2 + r2highx2 = r2high * x2 + + h4 += r1lowx2 + sr3lowx2 = sr3low * x2 + + h5 += r1highx2 + sr3highx2 = sr3high * x2 + + h6 += r2lowx2 + + h7 += r2highx2 + + h0 += sr3lowx2 + + h1 += sr3highx2 + +nomorebytes: + + y7 = h7 + alpha130 + + y0 = h0 + alpha32 + + y1 = h1 + alpha32 + + y2 = h2 + alpha64 + + y7 -= alpha130 + + y3 = h3 + alpha64 + + y4 = h4 + alpha96 + + y5 = h5 + alpha96 + + x7 = h7 - y7 + y7 *= scale + + y0 -= alpha32 + + y1 -= alpha32 + + y2 -= alpha64 + + h6 += x7 + + y3 -= alpha64 + + y4 -= alpha96 + + y5 -= alpha96 + + y6 = h6 + alpha130 + + x0 = h0 - y0 + + x1 = h1 - y1 + + x2 = h2 - y2 + + y6 -= alpha130 + + x0 += y7 + + x3 = h3 - y3 + + x4 = h4 - y4 + + x5 = h5 - y5 + + x6 = h6 - y6 + + y6 *= scale + + x2 += y0 + + x3 += y1 + + x4 += y2 + + x0 += y6 + + x5 += y3 + + x6 += y4 + + x2 += x3 + + x0 += x1 + + x4 += x5 + + x6 += y5 + + x2 += offset1 + d1 = int64(math.Float64bits(x2)) + + x0 += offset0 + d0 = int64(math.Float64bits(x0)) + + x4 += offset2 + d2 = int64(math.Float64bits(x4)) + + x6 += offset3 + d3 = int64(math.Float64bits(x6)) + + f0 = uint64(d0) + + f1 = uint64(d1) + bits32 = math.MaxUint64 + + f2 = uint64(d2) + bits32 >>= 32 + + f3 = uint64(d3) + f = f0 >> 32 + + f0 &= bits32 + f &= 255 + + f1 += f + g0 = f0 + 5 + + g = g0 >> 32 + g0 &= bits32 + + f = f1 >> 32 + f1 &= bits32 + + f &= 255 + g1 = f1 + g + + g = g1 >> 32 + f2 += f + + f = f2 >> 32 + g1 &= bits32 + + f2 &= bits32 + f &= 255 + + f3 += f + g2 = f2 + g + + g = g2 >> 32 + g2 &= bits32 + + f4 = f3 >> 32 + f3 &= bits32 + + f4 &= 255 + g3 = f3 + g + + g = g3 >> 32 + g3 &= bits32 + + g4 = f4 + g + + g4 = g4 - 4 + s00 = uint32(s[0]) + + f = uint64(int64(g4) >> 63) + s01 = uint32(s[1]) + + f0 &= f + g0 &^= f + s02 = uint32(s[2]) + + f1 &= f + f0 |= g0 + s03 = uint32(s[3]) + + g1 &^= f + f2 &= f + s10 = uint32(s[4]) + + f3 &= f + g2 &^= f + s11 = uint32(s[5]) + + g3 &^= f + f1 |= g1 + s12 = uint32(s[6]) + + f2 |= g2 + f3 |= g3 + s13 = uint32(s[7]) + + s01 <<= 8 + f0 += uint64(s00) + s20 = uint32(s[8]) + + s02 <<= 16 + f0 += uint64(s01) + s21 = uint32(s[9]) + + s03 <<= 24 + f0 += uint64(s02) + s22 = uint32(s[10]) + + s11 <<= 8 + f1 += uint64(s10) + s23 = uint32(s[11]) + + s12 <<= 16 + f1 += uint64(s11) + s30 = uint32(s[12]) + + s13 <<= 24 + f1 += uint64(s12) + s31 = uint32(s[13]) + + f0 += uint64(s03) + f1 += uint64(s13) + s32 = uint32(s[14]) + + s21 <<= 8 + f2 += uint64(s20) + s33 = uint32(s[15]) + + s22 <<= 16 + f2 += uint64(s21) + + s23 <<= 24 + f2 += uint64(s22) + + s31 <<= 8 + f3 += uint64(s30) + + s32 <<= 16 + f3 += uint64(s31) + + s33 <<= 24 + f3 += uint64(s32) + + f2 += uint64(s23) + f3 += uint64(s33) + + out[0] = byte(f0) + f0 >>= 8 + out[1] = byte(f0) + f0 >>= 8 + out[2] = byte(f0) + f0 >>= 8 + out[3] = byte(f0) + f0 >>= 8 + f1 += f0 + + out[4] = byte(f1) + f1 >>= 8 + out[5] = byte(f1) + f1 >>= 8 + out[6] = byte(f1) + f1 >>= 8 + out[7] = byte(f1) + f1 >>= 8 + f2 += f1 + + out[8] = byte(f2) + f2 >>= 8 + out[9] = byte(f2) + f2 >>= 8 + out[10] = byte(f2) + f2 >>= 8 + out[11] = byte(f2) + f2 >>= 8 + f3 += f2 + + out[12] = byte(f3) + f3 >>= 8 + out[13] = byte(f3) + f3 >>= 8 + out[14] = byte(f3) + f3 >>= 8 + out[15] = byte(f3) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt.go b/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt.go new file mode 100644 index 000000000..dc0124b1f --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt.go @@ -0,0 +1,243 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package scrypt implements the scrypt key derivation function as defined in +// Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard +// Functions" (http://www.tarsnap.com/scrypt/scrypt.pdf). +package scrypt // import "golang.org/x/crypto/scrypt" + +import ( + "crypto/sha256" + "errors" + + "golang.org/x/crypto/pbkdf2" +) + +const maxInt = int(^uint(0) >> 1) + +// blockCopy copies n numbers from src into dst. +func blockCopy(dst, src []uint32, n int) { + copy(dst, src[:n]) +} + +// blockXOR XORs numbers from dst with n numbers from src. +func blockXOR(dst, src []uint32, n int) { + for i, v := range src[:n] { + dst[i] ^= v + } +} + +// salsaXOR applies Salsa20/8 to the XOR of 16 numbers from tmp and in, +// and puts the result into both both tmp and out. +func salsaXOR(tmp *[16]uint32, in, out []uint32) { + w0 := tmp[0] ^ in[0] + w1 := tmp[1] ^ in[1] + w2 := tmp[2] ^ in[2] + w3 := tmp[3] ^ in[3] + w4 := tmp[4] ^ in[4] + w5 := tmp[5] ^ in[5] + w6 := tmp[6] ^ in[6] + w7 := tmp[7] ^ in[7] + w8 := tmp[8] ^ in[8] + w9 := tmp[9] ^ in[9] + w10 := tmp[10] ^ in[10] + w11 := tmp[11] ^ in[11] + w12 := tmp[12] ^ in[12] + w13 := tmp[13] ^ in[13] + w14 := tmp[14] ^ in[14] + w15 := tmp[15] ^ in[15] + + x0, x1, x2, x3, x4, x5, x6, x7, x8 := w0, w1, w2, w3, w4, w5, w6, w7, w8 + x9, x10, x11, x12, x13, x14, x15 := w9, w10, w11, w12, w13, w14, w15 + + for i := 0; i < 8; i += 2 { + u := x0 + x12 + x4 ^= u<<7 | u>>(32-7) + u = x4 + x0 + x8 ^= u<<9 | u>>(32-9) + u = x8 + x4 + x12 ^= u<<13 | u>>(32-13) + u = x12 + x8 + x0 ^= u<<18 | u>>(32-18) + + u = x5 + x1 + x9 ^= u<<7 | u>>(32-7) + u = x9 + x5 + x13 ^= u<<9 | u>>(32-9) + u = x13 + x9 + x1 ^= u<<13 | u>>(32-13) + u = x1 + x13 + x5 ^= u<<18 | u>>(32-18) + + u = x10 + x6 + x14 ^= u<<7 | u>>(32-7) + u = x14 + x10 + x2 ^= u<<9 | u>>(32-9) + u = x2 + x14 + x6 ^= u<<13 | u>>(32-13) + u = x6 + x2 + x10 ^= u<<18 | u>>(32-18) + + u = x15 + x11 + x3 ^= u<<7 | u>>(32-7) + u = x3 + x15 + x7 ^= u<<9 | u>>(32-9) + u = x7 + x3 + x11 ^= u<<13 | u>>(32-13) + u = x11 + x7 + x15 ^= u<<18 | u>>(32-18) + + u = x0 + x3 + x1 ^= u<<7 | u>>(32-7) + u = x1 + x0 + x2 ^= u<<9 | u>>(32-9) + u = x2 + x1 + x3 ^= u<<13 | u>>(32-13) + u = x3 + x2 + x0 ^= u<<18 | u>>(32-18) + + u = x5 + x4 + x6 ^= u<<7 | u>>(32-7) + u = x6 + x5 + x7 ^= u<<9 | u>>(32-9) + u = x7 + x6 + x4 ^= u<<13 | u>>(32-13) + u = x4 + x7 + x5 ^= u<<18 | u>>(32-18) + + u = x10 + x9 + x11 ^= u<<7 | u>>(32-7) + u = x11 + x10 + x8 ^= u<<9 | u>>(32-9) + u = x8 + x11 + x9 ^= u<<13 | u>>(32-13) + u = x9 + x8 + x10 ^= u<<18 | u>>(32-18) + + u = x15 + x14 + x12 ^= u<<7 | u>>(32-7) + u = x12 + x15 + x13 ^= u<<9 | u>>(32-9) + u = x13 + x12 + x14 ^= u<<13 | u>>(32-13) + u = x14 + x13 + x15 ^= u<<18 | u>>(32-18) + } + x0 += w0 + x1 += w1 + x2 += w2 + x3 += w3 + x4 += w4 + x5 += w5 + x6 += w6 + x7 += w7 + x8 += w8 + x9 += w9 + x10 += w10 + x11 += w11 + x12 += w12 + x13 += w13 + x14 += w14 + x15 += w15 + + out[0], tmp[0] = x0, x0 + out[1], tmp[1] = x1, x1 + out[2], tmp[2] = x2, x2 + out[3], tmp[3] = x3, x3 + out[4], tmp[4] = x4, x4 + out[5], tmp[5] = x5, x5 + out[6], tmp[6] = x6, x6 + out[7], tmp[7] = x7, x7 + out[8], tmp[8] = x8, x8 + out[9], tmp[9] = x9, x9 + out[10], tmp[10] = x10, x10 + out[11], tmp[11] = x11, x11 + out[12], tmp[12] = x12, x12 + out[13], tmp[13] = x13, x13 + out[14], tmp[14] = x14, x14 + out[15], tmp[15] = x15, x15 +} + +func blockMix(tmp *[16]uint32, in, out []uint32, r int) { + blockCopy(tmp[:], in[(2*r-1)*16:], 16) + for i := 0; i < 2*r; i += 2 { + salsaXOR(tmp, in[i*16:], out[i*8:]) + salsaXOR(tmp, in[i*16+16:], out[i*8+r*16:]) + } +} + +func integer(b []uint32, r int) uint64 { + j := (2*r - 1) * 16 + return uint64(b[j]) | uint64(b[j+1])<<32 +} + +func smix(b []byte, r, N int, v, xy []uint32) { + var tmp [16]uint32 + x := xy + y := xy[32*r:] + + j := 0 + for i := 0; i < 32*r; i++ { + x[i] = uint32(b[j]) | uint32(b[j+1])<<8 | uint32(b[j+2])<<16 | uint32(b[j+3])<<24 + j += 4 + } + for i := 0; i < N; i += 2 { + blockCopy(v[i*(32*r):], x, 32*r) + blockMix(&tmp, x, y, r) + + blockCopy(v[(i+1)*(32*r):], y, 32*r) + blockMix(&tmp, y, x, r) + } + for i := 0; i < N; i += 2 { + j := int(integer(x, r) & uint64(N-1)) + blockXOR(x, v[j*(32*r):], 32*r) + blockMix(&tmp, x, y, r) + + j = int(integer(y, r) & uint64(N-1)) + blockXOR(y, v[j*(32*r):], 32*r) + blockMix(&tmp, y, x, r) + } + j = 0 + for _, v := range x[:32*r] { + b[j+0] = byte(v >> 0) + b[j+1] = byte(v >> 8) + b[j+2] = byte(v >> 16) + b[j+3] = byte(v >> 24) + j += 4 + } +} + +// Key derives a key from the password, salt, and cost parameters, returning +// a byte slice of length keyLen that can be used as cryptographic key. +// +// N is a CPU/memory cost parameter, which must be a power of two greater than 1. +// r and p must satisfy r * p < 2³⁰. If the parameters do not satisfy the +// limits, the function returns a nil byte slice and an error. +// +// For example, you can get a derived key for e.g. AES-256 (which needs a +// 32-byte key) by doing: +// +// dk := scrypt.Key([]byte("some password"), salt, 16384, 8, 1, 32) +// +// The recommended parameters for interactive logins as of 2009 are N=16384, +// r=8, p=1. They should be increased as memory latency and CPU parallelism +// increases. Remember to get a good random salt. +func Key(password, salt []byte, N, r, p, keyLen int) ([]byte, error) { + if N <= 1 || N&(N-1) != 0 { + return nil, errors.New("scrypt: N must be > 1 and a power of 2") + } + if uint64(r)*uint64(p) >= 1<<30 || r > maxInt/128/p || r > maxInt/256 || N > maxInt/128/r { + return nil, errors.New("scrypt: parameters are too large") + } + + xy := make([]uint32, 64*r) + v := make([]uint32, 32*N*r) + b := pbkdf2.Key(password, salt, 1, p*128*r, sha256.New) + + for i := 0; i < p; i++ { + smix(b[i*128*r:], r, N, v, xy) + } + + return pbkdf2.Key(password, b, 1, keyLen, sha256.New), nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt_test.go b/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt_test.go new file mode 100644 index 000000000..e096c3a31 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/scrypt/scrypt_test.go @@ -0,0 +1,160 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package scrypt + +import ( + "bytes" + "testing" +) + +type testVector struct { + password string + salt string + N, r, p int + output []byte +} + +var good = []testVector{ + { + "password", + "salt", + 2, 10, 10, + []byte{ + 0x48, 0x2c, 0x85, 0x8e, 0x22, 0x90, 0x55, 0xe6, 0x2f, + 0x41, 0xe0, 0xec, 0x81, 0x9a, 0x5e, 0xe1, 0x8b, 0xdb, + 0x87, 0x25, 0x1a, 0x53, 0x4f, 0x75, 0xac, 0xd9, 0x5a, + 0xc5, 0xe5, 0xa, 0xa1, 0x5f, + }, + }, + { + "password", + "salt", + 16, 100, 100, + []byte{ + 0x88, 0xbd, 0x5e, 0xdb, 0x52, 0xd1, 0xdd, 0x0, 0x18, + 0x87, 0x72, 0xad, 0x36, 0x17, 0x12, 0x90, 0x22, 0x4e, + 0x74, 0x82, 0x95, 0x25, 0xb1, 0x8d, 0x73, 0x23, 0xa5, + 0x7f, 0x91, 0x96, 0x3c, 0x37, + }, + }, + { + "this is a long \000 password", + "and this is a long \000 salt", + 16384, 8, 1, + []byte{ + 0xc3, 0xf1, 0x82, 0xee, 0x2d, 0xec, 0x84, 0x6e, 0x70, + 0xa6, 0x94, 0x2f, 0xb5, 0x29, 0x98, 0x5a, 0x3a, 0x09, + 0x76, 0x5e, 0xf0, 0x4c, 0x61, 0x29, 0x23, 0xb1, 0x7f, + 0x18, 0x55, 0x5a, 0x37, 0x07, 0x6d, 0xeb, 0x2b, 0x98, + 0x30, 0xd6, 0x9d, 0xe5, 0x49, 0x26, 0x51, 0xe4, 0x50, + 0x6a, 0xe5, 0x77, 0x6d, 0x96, 0xd4, 0x0f, 0x67, 0xaa, + 0xee, 0x37, 0xe1, 0x77, 0x7b, 0x8a, 0xd5, 0xc3, 0x11, + 0x14, 0x32, 0xbb, 0x3b, 0x6f, 0x7e, 0x12, 0x64, 0x40, + 0x18, 0x79, 0xe6, 0x41, 0xae, + }, + }, + { + "p", + "s", + 2, 1, 1, + []byte{ + 0x48, 0xb0, 0xd2, 0xa8, 0xa3, 0x27, 0x26, 0x11, 0x98, + 0x4c, 0x50, 0xeb, 0xd6, 0x30, 0xaf, 0x52, + }, + }, + + { + "", + "", + 16, 1, 1, + []byte{ + 0x77, 0xd6, 0x57, 0x62, 0x38, 0x65, 0x7b, 0x20, 0x3b, + 0x19, 0xca, 0x42, 0xc1, 0x8a, 0x04, 0x97, 0xf1, 0x6b, + 0x48, 0x44, 0xe3, 0x07, 0x4a, 0xe8, 0xdf, 0xdf, 0xfa, + 0x3f, 0xed, 0xe2, 0x14, 0x42, 0xfc, 0xd0, 0x06, 0x9d, + 0xed, 0x09, 0x48, 0xf8, 0x32, 0x6a, 0x75, 0x3a, 0x0f, + 0xc8, 0x1f, 0x17, 0xe8, 0xd3, 0xe0, 0xfb, 0x2e, 0x0d, + 0x36, 0x28, 0xcf, 0x35, 0xe2, 0x0c, 0x38, 0xd1, 0x89, + 0x06, + }, + }, + { + "password", + "NaCl", + 1024, 8, 16, + []byte{ + 0xfd, 0xba, 0xbe, 0x1c, 0x9d, 0x34, 0x72, 0x00, 0x78, + 0x56, 0xe7, 0x19, 0x0d, 0x01, 0xe9, 0xfe, 0x7c, 0x6a, + 0xd7, 0xcb, 0xc8, 0x23, 0x78, 0x30, 0xe7, 0x73, 0x76, + 0x63, 0x4b, 0x37, 0x31, 0x62, 0x2e, 0xaf, 0x30, 0xd9, + 0x2e, 0x22, 0xa3, 0x88, 0x6f, 0xf1, 0x09, 0x27, 0x9d, + 0x98, 0x30, 0xda, 0xc7, 0x27, 0xaf, 0xb9, 0x4a, 0x83, + 0xee, 0x6d, 0x83, 0x60, 0xcb, 0xdf, 0xa2, 0xcc, 0x06, + 0x40, + }, + }, + { + "pleaseletmein", "SodiumChloride", + 16384, 8, 1, + []byte{ + 0x70, 0x23, 0xbd, 0xcb, 0x3a, 0xfd, 0x73, 0x48, 0x46, + 0x1c, 0x06, 0xcd, 0x81, 0xfd, 0x38, 0xeb, 0xfd, 0xa8, + 0xfb, 0xba, 0x90, 0x4f, 0x8e, 0x3e, 0xa9, 0xb5, 0x43, + 0xf6, 0x54, 0x5d, 0xa1, 0xf2, 0xd5, 0x43, 0x29, 0x55, + 0x61, 0x3f, 0x0f, 0xcf, 0x62, 0xd4, 0x97, 0x05, 0x24, + 0x2a, 0x9a, 0xf9, 0xe6, 0x1e, 0x85, 0xdc, 0x0d, 0x65, + 0x1e, 0x40, 0xdf, 0xcf, 0x01, 0x7b, 0x45, 0x57, 0x58, + 0x87, + }, + }, + /* + // Disabled: needs 1 GiB RAM and takes too long for a simple test. + { + "pleaseletmein", "SodiumChloride", + 1048576, 8, 1, + []byte{ + 0x21, 0x01, 0xcb, 0x9b, 0x6a, 0x51, 0x1a, 0xae, 0xad, + 0xdb, 0xbe, 0x09, 0xcf, 0x70, 0xf8, 0x81, 0xec, 0x56, + 0x8d, 0x57, 0x4a, 0x2f, 0xfd, 0x4d, 0xab, 0xe5, 0xee, + 0x98, 0x20, 0xad, 0xaa, 0x47, 0x8e, 0x56, 0xfd, 0x8f, + 0x4b, 0xa5, 0xd0, 0x9f, 0xfa, 0x1c, 0x6d, 0x92, 0x7c, + 0x40, 0xf4, 0xc3, 0x37, 0x30, 0x40, 0x49, 0xe8, 0xa9, + 0x52, 0xfb, 0xcb, 0xf4, 0x5c, 0x6f, 0xa7, 0x7a, 0x41, + 0xa4, + }, + }, + */ +} + +var bad = []testVector{ + {"p", "s", 0, 1, 1, nil}, // N == 0 + {"p", "s", 1, 1, 1, nil}, // N == 1 + {"p", "s", 7, 8, 1, nil}, // N is not power of 2 + {"p", "s", 16, maxInt / 2, maxInt / 2, nil}, // p * r too large +} + +func TestKey(t *testing.T) { + for i, v := range good { + k, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, len(v.output)) + if err != nil { + t.Errorf("%d: got unexpected error: %s", i, err) + } + if !bytes.Equal(k, v.output) { + t.Errorf("%d: expected %x, got %x", i, v.output, k) + } + } + for i, v := range bad { + _, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, 32) + if err == nil { + t.Errorf("%d: expected error, got nil", i) + } + } +} + +func BenchmarkKey(b *testing.B) { + for i := 0; i < b.N; i++ { + Key([]byte("password"), []byte("salt"), 16384, 8, 1, 64) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client.go new file mode 100644 index 000000000..31a0120c9 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client.go @@ -0,0 +1,563 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Package agent implements a client to an ssh-agent daemon. + +References: + [PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD +*/ +package agent // import "golang.org/x/crypto/ssh/agent" + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "sync" + + "golang.org/x/crypto/ssh" +) + +// Agent represents the capabilities of an ssh-agent. +type Agent interface { + // List returns the identities known to the agent. + List() ([]*Key, error) + + // Sign has the agent sign the data using a protocol 2 key as defined + // in [PROTOCOL.agent] section 2.6.2. + Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) + + // Insert adds a private key to the agent. If a certificate + // is given, that certificate is added as public key. + Add(s interface{}, cert *ssh.Certificate, comment string) error + + // Remove removes all identities with the given public key. + Remove(key ssh.PublicKey) error + + // RemoveAll removes all identities. + RemoveAll() error + + // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. + Lock(passphrase []byte) error + + // Unlock undoes the effect of Lock + Unlock(passphrase []byte) error + + // Signers returns signers for all the known keys. + Signers() ([]ssh.Signer, error) +} + +// See [PROTOCOL.agent], section 3. +const ( + agentRequestV1Identities = 1 + + // 3.2 Requests from client to agent for protocol 2 key operations + agentAddIdentity = 17 + agentRemoveIdentity = 18 + agentRemoveAllIdentities = 19 + agentAddIdConstrained = 25 + + // 3.3 Key-type independent requests from client to agent + agentAddSmartcardKey = 20 + agentRemoveSmartcardKey = 21 + agentLock = 22 + agentUnlock = 23 + agentAddSmartcardKeyConstrained = 26 + + // 3.7 Key constraint identifiers + agentConstrainLifetime = 1 + agentConstrainConfirm = 2 +) + +// maxAgentResponseBytes is the maximum agent reply size that is accepted. This +// is a sanity check, not a limit in the spec. +const maxAgentResponseBytes = 16 << 20 + +// Agent messages: +// These structures mirror the wire format of the corresponding ssh agent +// messages found in [PROTOCOL.agent]. + +// 3.4 Generic replies from agent to client +const agentFailure = 5 + +type failureAgentMsg struct{} + +const agentSuccess = 6 + +type successAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentRequestIdentities = 11 + +type requestIdentitiesAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentIdentitiesAnswer = 12 + +type identitiesAnswerAgentMsg struct { + NumKeys uint32 `sshtype:"12"` + Keys []byte `ssh:"rest"` +} + +// See [PROTOCOL.agent], section 2.6.2. +const agentSignRequest = 13 + +type signRequestAgentMsg struct { + KeyBlob []byte `sshtype:"13"` + Data []byte + Flags uint32 +} + +// See [PROTOCOL.agent], section 2.6.2. + +// 3.6 Replies from agent to client for protocol 2 key operations +const agentSignResponse = 14 + +type signResponseAgentMsg struct { + SigBlob []byte `sshtype:"14"` +} + +type publicKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +// Key represents a protocol 2 public key as defined in +// [PROTOCOL.agent], section 2.5.2. +type Key struct { + Format string + Blob []byte + Comment string +} + +func clientErr(err error) error { + return fmt.Errorf("agent: client error: %v", err) +} + +// String returns the storage form of an agent key with the format, base64 +// encoded serialized key, and the comment if it is not empty. +func (k *Key) String() string { + s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) + + if k.Comment != "" { + s += " " + k.Comment + } + + return s +} + +// Type returns the public key type. +func (k *Key) Type() string { + return k.Format +} + +// Marshal returns key blob to satisfy the ssh.PublicKey interface. +func (k *Key) Marshal() []byte { + return k.Blob +} + +// Verify satisfies the ssh.PublicKey interface, but is not +// implemented for agent keys. +func (k *Key) Verify(data []byte, sig *ssh.Signature) error { + return errors.New("agent: agent key does not know how to verify") +} + +type wireKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +func parseKey(in []byte) (out *Key, rest []byte, err error) { + var record struct { + Blob []byte + Comment string + Rest []byte `ssh:"rest"` + } + + if err := ssh.Unmarshal(in, &record); err != nil { + return nil, nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(record.Blob, &wk); err != nil { + return nil, nil, err + } + + return &Key{ + Format: wk.Format, + Blob: record.Blob, + Comment: record.Comment, + }, record.Rest, nil +} + +// client is a client for an ssh-agent process. +type client struct { + // conn is typically a *net.UnixConn + conn io.ReadWriter + // mu is used to prevent concurrent access to the agent + mu sync.Mutex +} + +// NewClient returns an Agent that talks to an ssh-agent process over +// the given connection. +func NewClient(rw io.ReadWriter) Agent { + return &client{conn: rw} +} + +// call sends an RPC to the agent. On success, the reply is +// unmarshaled into reply and replyType is set to the first byte of +// the reply, which contains the type of the message. +func (c *client) call(req []byte) (reply interface{}, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + msg := make([]byte, 4+len(req)) + binary.BigEndian.PutUint32(msg, uint32(len(req))) + copy(msg[4:], req) + if _, err = c.conn.Write(msg); err != nil { + return nil, clientErr(err) + } + + var respSizeBuf [4]byte + if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { + return nil, clientErr(err) + } + respSize := binary.BigEndian.Uint32(respSizeBuf[:]) + if respSize > maxAgentResponseBytes { + return nil, clientErr(err) + } + + buf := make([]byte, respSize) + if _, err = io.ReadFull(c.conn, buf); err != nil { + return nil, clientErr(err) + } + reply, err = unmarshal(buf) + if err != nil { + return nil, clientErr(err) + } + return reply, err +} + +func (c *client) simpleCall(req []byte) error { + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +func (c *client) RemoveAll() error { + return c.simpleCall([]byte{agentRemoveAllIdentities}) +} + +func (c *client) Remove(key ssh.PublicKey) error { + req := ssh.Marshal(&agentRemoveIdentityMsg{ + KeyBlob: key.Marshal(), + }) + return c.simpleCall(req) +} + +func (c *client) Lock(passphrase []byte) error { + req := ssh.Marshal(&agentLockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +func (c *client) Unlock(passphrase []byte) error { + req := ssh.Marshal(&agentUnlockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +// List returns the identities known to the agent. +func (c *client) List() ([]*Key, error) { + // see [PROTOCOL.agent] section 2.5.2. + req := []byte{agentRequestIdentities} + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *identitiesAnswerAgentMsg: + if msg.NumKeys > maxAgentResponseBytes/8 { + return nil, errors.New("agent: too many keys in agent reply") + } + keys := make([]*Key, msg.NumKeys) + data := msg.Keys + for i := uint32(0); i < msg.NumKeys; i++ { + var key *Key + var err error + if key, data, err = parseKey(data); err != nil { + return nil, err + } + keys[i] = key + } + return keys, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to list keys") + } + panic("unreachable") +} + +// Sign has the agent sign the data using a protocol 2 key as defined +// in [PROTOCOL.agent] section 2.6.2. +func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + req := ssh.Marshal(signRequestAgentMsg{ + KeyBlob: key.Marshal(), + Data: data, + }) + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *signResponseAgentMsg: + var sig ssh.Signature + if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { + return nil, err + } + + return &sig, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to sign challenge") + } + panic("unreachable") +} + +// unmarshal parses an agent message in packet, returning the parsed +// form and the message type of packet. +func unmarshal(packet []byte) (interface{}, error) { + if len(packet) < 1 { + return nil, errors.New("agent: empty packet") + } + var msg interface{} + switch packet[0] { + case agentFailure: + return new(failureAgentMsg), nil + case agentSuccess: + return new(successAgentMsg), nil + case agentIdentitiesAnswer: + msg = new(identitiesAnswerAgentMsg) + case agentSignResponse: + msg = new(signResponseAgentMsg) + default: + return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) + } + if err := ssh.Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +type rsaKeyMsg struct { + Type string `sshtype:"17"` + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string +} + +type dsaKeyMsg struct { + Type string `sshtype:"17"` + P *big.Int + Q *big.Int + G *big.Int + Y *big.Int + X *big.Int + Comments string +} + +type ecdsaKeyMsg struct { + Type string `sshtype:"17"` + Curve string + KeyBytes []byte + D *big.Int + Comments string +} + +// Insert adds a private key to the agent. +func (c *client) insertKey(s interface{}, comment string) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaKeyMsg{ + Type: ssh.KeyAlgoRSA, + N: k.N, + E: big.NewInt(int64(k.E)), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaKeyMsg{ + Type: ssh.KeyAlgoDSA, + P: k.P, + Q: k.Q, + G: k.G, + Y: k.Y, + X: k.X, + Comments: comment, + }) + case *ecdsa.PrivateKey: + nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) + req = ssh.Marshal(ecdsaKeyMsg{ + Type: "ecdsa-sha2-" + nistID, + Curve: nistID, + KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), + D: k.D, + Comments: comment, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +type rsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string +} + +type dsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + X *big.Int + Comments string +} + +type ecdsaCertMsg struct { + Type string `sshtype:"17"` + CertBytes []byte + D *big.Int + Comments string +} + +// Insert adds a private key to the agent. If a certificate is given, +// that certificate is added instead as public key. +func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error { + if cert == nil { + return c.insertKey(s, comment) + } else { + return c.insertCert(s, cert, comment) + } +} + +func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + X: k.X, + Comments: comment, + }) + case *ecdsa.PrivateKey: + req = ssh.Marshal(ecdsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Comments: comment, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + + signer, err := ssh.NewSignerFromKey(s) + if err != nil { + return err + } + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return errors.New("agent: signer and cert have different public key") + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +// Signers provides a callback for client authentication. +func (c *client) Signers() ([]ssh.Signer, error) { + keys, err := c.List() + if err != nil { + return nil, err + } + + var result []ssh.Signer + for _, k := range keys { + result = append(result, &agentKeyringSigner{c, k}) + } + return result, nil +} + +type agentKeyringSigner struct { + agent *client + pub ssh.PublicKey +} + +func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { + return s.pub +} + +func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + // The agent has its own entropy source, so the rand argument is ignored. + return s.agent.Sign(s.pub, data) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client_test.go new file mode 100644 index 000000000..80e2c2c38 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/client_test.go @@ -0,0 +1,278 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "errors" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + + "golang.org/x/crypto/ssh" +) + +// startAgent executes ssh-agent, and returns a Agent interface to it. +func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { + if testing.Short() { + // ssh-agent is not always available, and the key + // types supported vary by platform. + t.Skip("skipping test due to -short") + } + + bin, err := exec.LookPath("ssh-agent") + if err != nil { + t.Skip("could not find ssh-agent") + } + + cmd := exec.Command(bin, "-s") + out, err := cmd.Output() + if err != nil { + t.Fatalf("cmd.Output: %v", err) + } + + /* Output looks like: + + SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; + SSH_AGENT_PID=15542; export SSH_AGENT_PID; + echo Agent pid 15542; + */ + fields := bytes.Split(out, []byte(";")) + line := bytes.SplitN(fields[0], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AUTH_SOCK" { + t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) + } + socket = string(line[1]) + + line = bytes.SplitN(fields[2], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AGENT_PID" { + t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) + } + pidStr := line[1] + pid, err := strconv.Atoi(string(pidStr)) + if err != nil { + t.Fatalf("Atoi(%q): %v", pidStr, err) + } + + conn, err := net.Dial("unix", string(socket)) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + + ac := NewClient(conn) + return ac, socket, func() { + proc, _ := os.FindProcess(pid) + if proc != nil { + proc.Kill() + } + conn.Close() + os.RemoveAll(filepath.Dir(socket)) + } +} + +func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + testAgentInterface(t, agent, key, cert) +} + +func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) { + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("NewSignerFromKey(%T): %v", key, err) + } + // The agent should start up empty. + if keys, err := agent.List(); err != nil { + t.Fatalf("RequestIdentities: %v", err) + } else if len(keys) > 0 { + t.Fatalf("got %d keys, want 0: %v", len(keys), keys) + } + + // Attempt to insert the key, with certificate if specified. + var pubKey ssh.PublicKey + if cert != nil { + err = agent.Add(key, cert, "comment") + pubKey = cert + } else { + err = agent.Add(key, nil, "comment") + pubKey = signer.PublicKey() + } + if err != nil { + t.Fatalf("insert(%T): %v", key, err) + } + + // Did the key get inserted successfully? + if keys, err := agent.List(); err != nil { + t.Fatalf("List: %v", err) + } else if len(keys) != 1 { + t.Fatalf("got %v, want 1 key", keys) + } else if keys[0].Comment != "comment" { + t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") + } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { + t.Fatalf("key mismatch") + } + + // Can the agent make a valid signature? + data := []byte("hello") + sig, err := agent.Sign(pubKey, data) + if err != nil { + t.Fatalf("Sign(%s): %v", pubKey.Type(), err) + } + + if err := pubKey.Verify(data, sig); err != nil { + t.Fatalf("Verify(%s): %v", pubKey.Type(), err) + } +} + +func TestAgent(t *testing.T) { + for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { + testAgent(t, testPrivateKeys[keyType], nil) + } +} + +func TestCert(t *testing.T) { + cert := &ssh.Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: ssh.CertTimeInfinity, + CertType: ssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + testAgent(t, testPrivateKeys["rsa"], cert) +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func TestAuth(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + agent, _, cleanup := startAgent(t) + defer cleanup() + + if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil { + t.Errorf("Add: %v", err) + } + + serverConf := ssh.ServerConfig{} + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, errors.New("pubkey rejected") + } + + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + conn.Close() + }() + + conf := ssh.ClientConfig{} + conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) + conn, _, _, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + conn.Close() +} + +func TestLockClient(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + testLockAgent(agent, t) +} + +func testLockAgent(agent Agent, t *testing.T) { + if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil { + t.Errorf("Add: %v", err) + } + if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil { + t.Errorf("Add: %v", err) + } + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 2 { + t.Errorf("Want 2 keys, got %v", keys) + } + + passphrase := []byte("secret") + if err := agent.Lock(passphrase); err != nil { + t.Errorf("Lock: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", keys) + } + + signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) + if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { + t.Fatalf("Sign did not fail") + } + + if err := agent.Remove(signer.PublicKey()); err == nil { + t.Fatalf("Remove did not fail") + } + + if err := agent.RemoveAll(); err == nil { + t.Fatalf("RemoveAll did not fail") + } + + if err := agent.Unlock(nil); err == nil { + t.Errorf("Unlock with wrong passphrase succeeded") + } + if err := agent.Unlock(passphrase); err != nil { + t.Errorf("Unlock: %v", err) + } + + if err := agent.Remove(signer.PublicKey()); err != nil { + t.Fatalf("Remove: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 1 { + t.Errorf("Want 1 keys, got %v", keys) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/forward.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/forward.go new file mode 100644 index 000000000..fd24ba900 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/forward.go @@ -0,0 +1,103 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "errors" + "io" + "net" + "sync" + + "golang.org/x/crypto/ssh" +) + +// RequestAgentForwarding sets up agent forwarding for the session. +// ForwardToAgent or ForwardToRemote should be called to route +// the authentication requests. +func RequestAgentForwarding(session *ssh.Session) error { + ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) + if err != nil { + return err + } + if !ok { + return errors.New("forwarding request denied") + } + return nil +} + +// ForwardToAgent routes authentication requests to the given keyring. +func ForwardToAgent(client *ssh.Client, keyring Agent) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go func() { + ServeAgent(keyring, channel) + channel.Close() + }() + } + }() + return nil +} + +const channelType = "auth-agent@openssh.com" + +// ForwardToRemote routes authentication requests to the ssh-agent +// process serving on the given unix socket. +func ForwardToRemote(client *ssh.Client, addr string) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + conn, err := net.Dial("unix", addr) + if err != nil { + return err + } + conn.Close() + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go forwardUnixSocket(channel, addr) + } + }() + return nil +} + +func forwardUnixSocket(channel ssh.Channel, addr string) { + conn, err := net.Dial("unix", addr) + if err != nil { + return + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + + wg.Wait() + conn.Close() + channel.Close() +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/keyring.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/keyring.go new file mode 100644 index 000000000..8ac20090a --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/keyring.go @@ -0,0 +1,183 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "sync" + + "golang.org/x/crypto/ssh" +) + +type privKey struct { + signer ssh.Signer + comment string +} + +type keyring struct { + mu sync.Mutex + keys []privKey + + locked bool + passphrase []byte +} + +var errLocked = errors.New("agent: locked") + +// NewKeyring returns an Agent that holds keys in memory. It is safe +// for concurrent use by multiple goroutines. +func NewKeyring() Agent { + return &keyring{} +} + +// RemoveAll removes all identities. +func (r *keyring) RemoveAll() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.keys = nil + return nil +} + +// Remove removes all identities with the given public key. +func (r *keyring) Remove(key ssh.PublicKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + want := key.Marshal() + found := false + for i := 0; i < len(r.keys); { + if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { + found = true + r.keys[i] = r.keys[len(r.keys)-1] + r.keys = r.keys[len(r.keys)-1:] + continue + } else { + i++ + } + } + + if !found { + return errors.New("agent: key not found") + } + return nil +} + +// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. +func (r *keyring) Lock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.locked = true + r.passphrase = passphrase + return nil +} + +// Unlock undoes the effect of Lock +func (r *keyring) Unlock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if !r.locked { + return errors.New("agent: not locked") + } + if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { + return fmt.Errorf("agent: incorrect passphrase") + } + + r.locked = false + r.passphrase = nil + return nil +} + +// List returns the identities known to the agent. +func (r *keyring) List() ([]*Key, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + // section 2.7: locked agents return empty. + return nil, nil + } + + var ids []*Key + for _, k := range r.keys { + pub := k.signer.PublicKey() + ids = append(ids, &Key{ + Format: pub.Type(), + Blob: pub.Marshal(), + Comment: k.comment}) + } + return ids, nil +} + +// Insert adds a private key to the keyring. If a certificate +// is given, that certificate is added as public key. +func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + signer, err := ssh.NewSignerFromKey(priv) + + if err != nil { + return err + } + + if cert != nil { + signer, err = ssh.NewCertSigner(cert, signer) + if err != nil { + return err + } + } + + r.keys = append(r.keys, privKey{signer, comment}) + + return nil +} + +// Sign returns a signature for the data. +func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + wanted := key.Marshal() + for _, k := range r.keys { + if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { + return k.signer.Sign(rand.Reader, data) + } + } + return nil, errors.New("not found") +} + +// Signers returns signers for all the known keys. +func (r *keyring) Signers() ([]ssh.Signer, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + s := make([]ssh.Signer, 0, len(r.keys)) + for _, k := range r.keys { + s = append(s, k.signer) + } + return s, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server.go new file mode 100644 index 000000000..be9df0eb0 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server.go @@ -0,0 +1,209 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto/rsa" + "encoding/binary" + "fmt" + "io" + "log" + "math/big" + + "golang.org/x/crypto/ssh" +) + +// Server wraps an Agent and uses it to implement the agent side of +// the SSH-agent, wire protocol. +type server struct { + agent Agent +} + +func (s *server) processRequestBytes(reqData []byte) []byte { + rep, err := s.processRequest(reqData) + if err != nil { + if err != errLocked { + // TODO(hanwen): provide better logging interface? + log.Printf("agent %d: %v", reqData[0], err) + } + return []byte{agentFailure} + } + + if err == nil && rep == nil { + return []byte{agentSuccess} + } + + return ssh.Marshal(rep) +} + +func marshalKey(k *Key) []byte { + var record struct { + Blob []byte + Comment string + } + record.Blob = k.Marshal() + record.Comment = k.Comment + + return ssh.Marshal(&record) +} + +type agentV1IdentityMsg struct { + Numkeys uint32 `sshtype:"2"` +} + +type agentRemoveIdentityMsg struct { + KeyBlob []byte `sshtype:"18"` +} + +type agentLockMsg struct { + Passphrase []byte `sshtype:"22"` +} + +type agentUnlockMsg struct { + Passphrase []byte `sshtype:"23"` +} + +func (s *server) processRequest(data []byte) (interface{}, error) { + switch data[0] { + case agentRequestV1Identities: + return &agentV1IdentityMsg{0}, nil + case agentRemoveIdentity: + var req agentRemoveIdentityMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) + + case agentRemoveAllIdentities: + return nil, s.agent.RemoveAll() + + case agentLock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + return nil, s.agent.Lock(req.Passphrase) + + case agentUnlock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + return nil, s.agent.Unlock(req.Passphrase) + + case agentSignRequest: + var req signRequestAgentMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + k := &Key{ + Format: wk.Format, + Blob: req.KeyBlob, + } + + sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags. + if err != nil { + return nil, err + } + return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil + case agentRequestIdentities: + keys, err := s.agent.List() + if err != nil { + return nil, err + } + + rep := identitiesAnswerAgentMsg{ + NumKeys: uint32(len(keys)), + } + for _, k := range keys { + rep.Keys = append(rep.Keys, marshalKey(k)...) + } + return rep, nil + case agentAddIdentity: + return nil, s.insertIdentity(data) + } + + return nil, fmt.Errorf("unknown opcode %d", data[0]) +} + +func (s *server) insertIdentity(req []byte) error { + var record struct { + Type string `sshtype:"17"` + Rest []byte `ssh:"rest"` + } + if err := ssh.Unmarshal(req, &record); err != nil { + return err + } + + switch record.Type { + case ssh.KeyAlgoRSA: + var k rsaKeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return err + } + + priv := rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(k.E.Int64()), + N: k.N, + }, + D: k.D, + Primes: []*big.Int{k.P, k.Q}, + } + priv.Precompute() + + return s.agent.Add(&priv, nil, k.Comments) + } + return fmt.Errorf("not implemented: %s", record.Type) +} + +// ServeAgent serves the agent protocol on the given connection. It +// returns when an I/O error occurs. +func ServeAgent(agent Agent, c io.ReadWriter) error { + s := &server{agent} + + var length [4]byte + for { + if _, err := io.ReadFull(c, length[:]); err != nil { + return err + } + l := binary.BigEndian.Uint32(length[:]) + if l > maxAgentResponseBytes { + // We also cap requests. + return fmt.Errorf("agent: request too large: %d", l) + } + + req := make([]byte, l) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + + repData := s.processRequestBytes(req) + if len(repData) > maxAgentResponseBytes { + return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) + } + + binary.BigEndian.PutUint32(length[:], uint32(len(repData))) + if _, err := c.Write(length[:]); err != nil { + return err + } + if _, err := c.Write(repData); err != nil { + return err + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server_test.go new file mode 100644 index 000000000..def5f8ccc --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/server_test.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestServer(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + client := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testAgentInterface(t, client, testPrivateKeys["rsa"], nil) +} + +func TestLockServer(t *testing.T) { + testLockAgent(NewKeyring(), t) +} + +func TestSetupForwardAgent(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + _, socket, cleanup := startAgent(t) + defer cleanup() + + serverConf := ssh.ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + incoming := make(chan *ssh.ServerConn, 1) + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + incoming <- conn + }() + + conf := ssh.ClientConfig{} + conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + client := ssh.NewClient(conn, chans, reqs) + + if err := ForwardToRemote(client, socket); err != nil { + t.Fatalf("SetupForwardAgent: %v", err) + } + + server := <-incoming + ch, reqs, err := server.OpenChannel(channelType, nil) + if err != nil { + t.Fatalf("OpenChannel(%q): %v", channelType, err) + } + go ssh.DiscardRequests(reqs) + + agentClient := NewClient(ch) + testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil) + conn.Close() +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/testdata_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/testdata_test.go new file mode 100644 index 000000000..b7a8781e1 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/agent/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package agent + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/benchmark_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/benchmark_test.go new file mode 100644 index 000000000..d9f7eb9b6 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/benchmark_test.go @@ -0,0 +1,122 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + b.Fatalf("Client: %v", err) + } + ch, incoming, err := newCh.Accept() + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + b.Fatalf("ReadFull: %v", err) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer.go new file mode 100644 index 000000000..6931b5114 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer.go @@ -0,0 +1,98 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive os.EOF. +func (b *buffer) eof() error { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() + return nil +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer_test.go new file mode 100644 index 000000000..d5781cb3d --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs.go new file mode 100644 index 000000000..a7426b602 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs.go @@ -0,0 +1,501 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// These constants from [PROTOCOL.certkeys] represent the algorithm names +// for certificate types supported by this package. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return nil, errors.New("ssh: signer and cert have different public key") + } + + return &openSSHCertSigner{cert, signer}, nil +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsAuthority should return true if the key is recognized as + // an authority. This allows for certificates to be signed by other + // certificates. + IsAuthority func(auth PublicKey) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + + return c.CheckCert(addr, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) + } + + for opt, _ := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + if !c.IsAuthority(cert.SignatureKey) { + return fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert sets c.SignatureKey to the authority's public key and stores a +// Signature, by authority, in the certificate. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +var certAlgoNames = map[string]string{ + KeyAlgoRSA: CertAlgoRSAv01, + KeyAlgoDSA: CertAlgoDSAv01, + KeyAlgoECDSA256: CertAlgoECDSA256v01, + KeyAlgoECDSA384: CertAlgoECDSA384v01, + KeyAlgoECDSA521: CertAlgoECDSA521v01, +} + +// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. +// Panics if a non-certificate algorithm is passed. +func certToPrivAlgo(algo string) string { + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + return privAlgo + } + } + panic("unknown cert algorithm") +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the key name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + algo, ok := certAlgoNames[c.Key.Type()] + if !ok { + panic("unknown cert key type") + } + return algo +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs_test.go new file mode 100644 index 000000000..33236f5e8 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/certs_test.go @@ -0,0 +1,211 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + "time" +) + +// Cert generated by ssh-keygen 6.0p1 Debian-4. +// % ssh-keygen -s ca-key -I test user-key +const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` + +func TestParseCert(t *testing.T) { + authKeyBytes := []byte(exampleSSHCert) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// Extensions: +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("user", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("user", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsAuthority: func(p PublicKey) bool { + return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, name := range []string{"hostname", "otherhost"} { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(certSigner) + _, _, _, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("NewServerConn: %v", err) + } + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + } + _, _, _, err = NewClientConn(c2, name, config) + + succeed := name == "hostname" + if (err == nil) != succeed { + t.Fatalf("NewClientConn(%q): %v", name, err) + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/channel.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/channel.go new file mode 100644 index 000000000..5403c7e45 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/channel.go @@ -0,0 +1,631 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (c *channel) writePacket(packet []byte) error { + c.writeMu.Lock() + if c.sentClose { + c.writeMu.Unlock() + return io.EOF + } + c.sentClose = (packet[0] == msgChannelClose) + err := c.mux.conn.writePacket(packet) + c.writeMu.Unlock() + return err +} + +func (c *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send %d: %#v", c.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], c.remoteId) + return c.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if c.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + c.writeMu.Lock() + packet := c.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + c.writeMu.Unlock() + + for len(data) > 0 { + space := min(c.maxRemotePayload, len(data)) + if space, err = c.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], c.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = c.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + c.writeMu.Lock() + c.packetPool[extendedCode] = packet + c.writeMu.Unlock() + + return n, err +} + +func (c *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > c.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + c.windowMu.Lock() + if c.myWindow < length { + c.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + c.myWindow -= length + c.windowMu.Unlock() + + if extended == 1 { + c.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + c.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: uint32(n), + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necesary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (c *channel) responseMessageReceived() error { + if c.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if c.decided { + return errors.New("ssh: duplicate response received for channel") + } + c.decided = true + return nil +} + +func (c *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return c.handleData(packet) + case msgChannelClose: + c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) + c.mux.chanList.remove(c.localId) + c.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + c.extPending.eof() + c.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + c.mux.chanList.remove(msg.PeersId) + c.msg <- msg + case *channelOpenConfirmMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + c.remoteId = msg.MyId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.MyWindow) + c.msg <- msg + case *windowAdjustMsg: + if !c.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: c, + } + + c.incomingRequests <- &req + default: + c.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, 16), + msg: make(chan interface{}, 16), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (c *channel) Accept() (Channel, <-chan *Request, error) { + if c.decided { + return nil, nil, errDecidedAlready + } + c.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersId: c.remoteId, + MyId: c.localId, + MyWindow: c.myWindow, + MaxPacketSize: c.maxIncomingPayload, + } + c.decided = true + if err := c.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return c, c.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersId: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersId: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersId: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersId: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersId: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersId: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher.go new file mode 100644 index 000000000..b17318385 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher.go @@ -0,0 +1,523 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type streamCipherMode struct { + keySize int + ivSize int + skip int + createFunc func(key, iv []byte) (cipher.Stream, error) +} + +func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { + if len(key) < c.keySize { + panic("ssh: key length too small for cipher") + } + if len(iv) < c.ivSize { + panic("ssh: iv too small for cipher") + } + + stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) + if err != nil { + return nil, err + } + + var streamDump []byte + if c.skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := c.skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + return stream, nil +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*streamCipherMode{ + // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, + "aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, + "aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, + + // Ciphers from RFC4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, 1536, newRC4}, + "arcfour256": {32, 0, 1536, newRC4}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, 0, newRC4}, + + // AES-GCM is not a stream cipher, so it is constructed with a + // special case. If we add any more non-stream ciphers, we + // should invest a cleaner way to do this. + gcmCipherID: {16, 12, 0, nil}, + + // insecure cipher, see http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf + // uncomment below to enable it. + // aes128cbcID: {16, aes.BlockSize, 0, nil}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + s.mac.Write(data) + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writePacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + s.mac.Write(packet) + s.mac.Write(padding) + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded.") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + padding := plain[0] + if padding < 4 || padding >= 20 { + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte +} + +func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + }, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := (prefixLen + blockSize - 1) / blockSize * blockSize + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, errors.New("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, errors.New("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, errors.New("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length") + } + + var macSize uint32 + if c.mac != nil { + macSize = uint32(c.mac.Size()) + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + if _, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { + return nil, err + } + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + var macSize uint32 + if c.mac != nil { + macSize = uint32(c.mac.Size()) + } + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher_test.go new file mode 100644 index 000000000..2fb75d0d7 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/cipher_test.go @@ -0,0 +1,64 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/rand" + "testing" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("default cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + // Still test aes128cbc cipher althought it's commented out. + cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} + defer delete(cipherModes, aes128cbcID) + + for cipher := range cipherModes { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket(%q): %v", cipher, err) + continue + } + + packet, err := server.readPacket(0, buf) + if err != nil { + t.Errorf("readPacket(%q): %v", cipher, err) + continue + } + + if string(packet) != want { + t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/client.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client.go new file mode 100644 index 000000000..72bd27f1c --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client.go @@ -0,0 +1,206 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "net" + "sync" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, port forwarding and tunneled dialing. +type Client struct { + Conn + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, 16) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + conn := &connection{ + sshConn: sshConn{conn: c}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.requestKeyChange(); err != nil { + return err + } + + if packet, err := c.transport.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + // We just did the key change, so the session ID is established. + c.sessionID = c.transport.getSessionID() + + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key +// exchange. +func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback, if not nil, is called during the cryptographic + // handshake to validate the server's host key. A nil HostKeyCallback + // implies that all host keys are accepted. + HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth.go new file mode 100644 index 000000000..e15be3ef2 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth.go @@ -0,0 +1,441 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + tried := make(map[string]bool) + var lastMethods []string + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) + if err != nil { + return err + } + if ok { + // success + return nil + } + tried[auth.method()] = true + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if tried[candidateMethod] { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) +} + +func keys(m map[string]bool) []string { + s := make([]string, 0, len(m)) + + for key := range m { + s = append(s, key) + } + return s +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return false, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + // Authentication is performed in two stages. The first stage sends an + // enquiry to test if each key is acceptable to the remote. The second + // stage attempts to authenticate with the valid keys obtained in the + // first stage. + + signers, err := cb() + if err != nil { + return false, nil, err + } + var validKeys []Signer + for _, signer := range signers { + if ok, err := validateKey(signer.PublicKey(), user, c); ok { + validKeys = append(validKeys, signer) + } else { + if err != nil { + return false, nil, err + } + } + } + + // methods that may continue if this auth is not successful. + var methods []string + for _, signer := range validKeys { + pub := signer.PublicKey() + + pubKey := pub.Marshal() + sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, []byte(pub.Type()), pubKey)) + if err != nil { + return false, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: pub.Type(), + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return false, nil, err + } + var success bool + success, methods, err = handleAuthResponse(c) + if err != nil { + return false, nil, err + } + if success { + return success, methods, err + } + } + return false, methods, nil +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: key.Type(), + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + algoname := key.Type() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + // TODO(gpaul): add callback to present the banner to the user + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (bool, []string, error) { + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + // TODO: add callback to present the banner to the user + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + case msgDisconnect: + return false, nil, io.EOF + default: + return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the user and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns a AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return false, nil, err + } + + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + // TODO: Print banners during userauth. + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + default: + return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return false, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.User, msg.Instruction, prompts, echos) + if err != nil { + return false, nil, err + } + + if len(answers) != len(prompts) { + return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return false, nil, err + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth_test.go new file mode 100644 index 000000000..c92b58786 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_auth_test.go @@ -0,0 +1,393 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig +func tryAuth(t *testing.T, config *ClientConfig) error { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + AuthLogCallback: func(conn ConnMetadata, method string, err error) { + t.Logf("user %q, method %q: %v", conn.User(), method, err) + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err +} + +func TestClientAuthPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") { + t.Errorf("got %v, expected 'no common algorithms'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + t.Log("should succeed") + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("corrupted signature") + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + t.Log("revoked") + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + t.Log("sign with wrong key") + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritive key") + } + + t.Log("host cert") + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + t.Log("principal specified") + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("wrong principal specified") + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + t.Log("added critical option") + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + t.Log("allowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + t.Log("disallowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } else { + return &Permissions{}, nil + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_test.go new file mode 100644 index 000000000..1fe790cb4 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/client_test.go @@ -0,0 +1,39 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "net" + "testing" +) + +func testClientVersion(t *testing.T, config *ClientConfig, expected string) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + receivedVersion := make(chan string, 1) + go func() { + version, err := readVersion(serverConn) + if err != nil { + receivedVersion <- "" + } else { + receivedVersion <- string(version) + } + serverConn.Close() + }() + NewClientConn(clientConn, "", config) + actual := <-receivedVersion + if actual != expected { + t.Fatalf("got %s; want %s", actual, expected) + } +} + +func TestCustomClientVersion(t *testing.T) { + version := "Test-Client-Version-0.0" + testClientVersion(t, &ClientConfig{ClientVersion: version}, version) +} + +func TestDefaultClientVersion(t *testing.T) { + testClientVersion(t, &ClientConfig{}, packageVersion) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/common.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/common.go new file mode 100644 index 000000000..b03a3dfd8 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/common.go @@ -0,0 +1,365 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers specifies the supported ciphers in preference order. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", + "arcfour256", "arcfour128", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, kexAlgoDH1SHA1, +} + +// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSA, KeyAlgoDSA, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported algorithms to their respective +// hashes needed for signature verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + CertAlgoRSAv01: crypto.SHA1, + CertAlgoDSAv01: crypto.SHA1, + CertAlgoECDSA256v01: crypto.SHA256, + CertAlgoECDSA384v01: crypto.SHA384, + CertAlgoECDSA521v01: crypto.SHA512, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) { + for _, clientAlgo := range clientAlgos { + for _, serverAlgo := range serverAlgos { + if clientAlgo == serverAlgo { + return clientAlgo, true + } + } + } + return +} + +func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCipher string, ok bool) { + for _, clientCipher := range clientCiphers { + for _, serverCipher := range serverCiphers { + // reject the cipher if we have no cipherModes definition + if clientCipher == serverCipher && cipherModes[clientCipher] != nil { + return clientCipher, true + } + } + } + return +} + +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) { + var ok bool + result := &algorithms{} + result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if !ok { + return + } + + result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if !ok { + return + } + + result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if !ok { + return + } + + result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if !ok { + return + } + + result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if !ok { + return + } + + result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if !ok { + return + } + + result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if !ok { + return + } + + result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if !ok { + return + } + + return result +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, 1 gigabyte is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a + // default set of algorithms is used. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible + // default is used. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default + // is used. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = supportedCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // reject the cipher if we have no cipherModes definition + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = supportedKexAlgos + } + + if c.MACs == nil { + c.MACs = supportedMACs + } + + if c.RekeyThreshold == 0 { + // RFC 4253, section 9 suggests rekeying after 1G. + c.RekeyThreshold = 1 << 30 + } + if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo []byte + PubKey []byte + }{ + sessionId, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/connection.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/connection.go new file mode 100644 index 000000000..93551e241 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/connection.go @@ -0,0 +1,144 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + // It is empty if no authentication is used. + User() string + + // SessionID returns the sesson hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the client's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshconn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/doc.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/doc.go new file mode 100644 index 000000000..d6be89466 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/doc.go @@ -0,0 +1,18 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 +*/ +package ssh // import "golang.org/x/crypto/ssh" diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/example_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/example_test.go new file mode 100644 index 000000000..dfd9dcab6 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/example_test.go @@ -0,0 +1,211 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh_test + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func ExampleNewServerConn() { + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + panic("Failed to load private key") + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + panic("Failed to parse private key") + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + panic("failed to listen for connection") + } + nConn, err := listener.Accept() + if err != nil { + panic("failed to accept incoming connection") + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + panic("failed to handshake") + } + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + panic("could not accept channel.") + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + go func(in <-chan *ssh.Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any + // commands, only the + // default shell. + ok = false + } + } + req.Reply(ok, nil) + } + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + go func() { + defer channel.Close() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleDial() { + // An SSH client is represented with a ClientConn. Currently only + // the "password" authentication method is supported. + // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig. + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + } + client, err := ssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + panic("Failed to dial: " + err.Error()) + } + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + panic("Failed to create session: " + err.Error()) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + panic("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExampleClient_Listen() { + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + } + // Dial your ssh server. + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatalf("unable to connect: %s", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatalf("unable to register tcp forward: %v", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + // Create client config + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + } + // Connect to ssh server + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatalf("unable to connect: %s", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatalf("unable to create session: %s", err) + } + defer session.Close() + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 80, 40, modes); err != nil { + log.Fatalf("request for pseudo terminal failed: %s", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatalf("failed to start shell: %s", err) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake.go new file mode 100644 index 000000000..a1e2c23da --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake.go @@ -0,0 +1,393 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error + + // getSessionID returns the session ID. prepareKeyChange must + // have been called once. + getSessionID() []byte +} + +// rekeyingTransport is the interface of handshakeTransport that we +// (internally) expose to ClientConn and ServerConn. +type rekeyingTransport interface { + packetConn + + // requestKeyChange asks the remote side to change keys. All + // writes are blocked until the key change succeeds, which is + // signaled by reading a msgNewKeys. + requestKeyChange() error + + // getSessionID returns the session ID. This is only valid + // after the first key change has completed. + getSessionID() []byte +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + hostKeys []Signer // If hostKeys are given, we are the server. + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + // data for host key checking + hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + dialAddress string + remoteAddr net.Addr + + readSinceKex uint64 + + // Protects the writing side of the connection + mu sync.Mutex + cond *sync.Cond + sentInitPacket []byte + sentInitMsg *kexInitMsg + writtenSinceKex uint64 + writeError error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, 16), + config: config, + } + t.cond = sync.NewCond(&t.mu) + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + go t.readLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + go t.readLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.conn.getSessionID() +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + for { + p, err := t.readOnePacket() + if err != nil { + t.readError = err + close(t.incoming) + break + } + if p[0] == msgIgnore || p[0] == msgDebug { + continue + } + t.incoming <- p + } +} + +func (t *handshakeTransport) readOnePacket() ([]byte, error) { + if t.readSinceKex > t.config.RekeyThreshold { + if err := t.requestKeyChange(); err != nil { + return nil, err + } + } + + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + t.readSinceKex += uint64(len(p)) + if debugHandshake { + msg, err := decode(p) + log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) + } + if p[0] != msgKexInit { + return p, nil + } + err = t.enterKeyExchange(p) + + t.mu.Lock() + if err != nil { + // drop connection + t.conn.Close() + t.writeError = err + } + + if debugHandshake { + log.Printf("%s exited key exchange, err %v", t.id(), err) + } + + // Unblock writers. + t.sentInitMsg = nil + t.sentInitPacket = nil + t.cond.Broadcast() + t.writtenSinceKex = 0 + t.mu.Unlock() + + if err != nil { + return nil, err + } + + t.readSinceKex = 0 + return []byte{msgNewKeys}, nil +} + +// sendKexInit sends a key change message, and returns the message +// that was sent. After initiating the key change, all writes will be +// blocked until the change is done, and a failed key change will +// close the underlying transport. This function is safe for +// concurrent use by multiple goroutines. +func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { + t.mu.Lock() + defer t.mu.Unlock() + return t.sendKexInitLocked() +} + +func (t *handshakeTransport) requestKeyChange() error { + _, _, err := t.sendKexInit() + return err +} + +// sendKexInitLocked sends a key change message. t.mu must be locked +// while this happens. +func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + if t.sentInitMsg != nil { + return t.sentInitMsg, t.sentInitPacket, nil + } + msg := &kexInitMsg{ + KexAlgos: t.config.KeyExchanges, + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + if len(t.hostKeys) > 0 { + for _, k := range t.hostKeys { + msg.ServerHostKeyAlgos = append( + msg.ServerHostKeyAlgos, k.PublicKey().Type()) + } + } else { + msg.ServerHostKeyAlgos = supportedHostKeyAlgos + } + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.conn.writePacket(packetCopy); err != nil { + return nil, nil, err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + return msg, packet, nil +} + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + if t.writtenSinceKex > t.config.RekeyThreshold { + t.sendKexInitLocked() + } + for t.sentInitMsg != nil { + t.cond.Wait() + } + if t.writeError != nil { + return t.writeError + } + t.writtenSinceKex += uint64(len(p)) + + var err error + switch p[0] { + case msgKexInit: + err = errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + err = errors.New("ssh: only handshakeTransport can send newKeys") + default: + err = t.conn.writePacket(p) + } + t.mu.Unlock() + return err +} + +func (t *handshakeTransport) Close() error { + return t.conn.Close() +} + +// enterKeyExchange runs the key exchange. +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + myInit, myInitPacket, err := t.sendKexInit() + if err != nil { + return err + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: myInitPacket, + } + + clientInit := otherInit + serverInit := myInit + if len(t.hostKeys) == 0 { + clientInit = myInit + serverInit = otherInit + + magics.clientKexInit = myInitPacket + magics.serverKexInit = otherInitPacket + } + + algs := findAgreedAlgorithms(clientInit, serverInit) + if algs == nil { + return errors.New("ssh: no common algorithms") + } + + // We don't send FirstKexFollows, but we handle receiving it. + if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[algs.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, algs, &magics) + } else { + result, err = t.client(kex, algs, &magics) + } + + if err != nil { + return err + } + + t.conn.prepareKeyChange(algs, result) + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + var hostKey Signer + for _, k := range t.hostKeys { + if algs.hostKey == k.PublicKey().Type() { + hostKey = k + } + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, result); err != nil { + return nil, err + } + + if t.hostKeyCallback != nil { + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + } + + return result, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake_test.go new file mode 100644 index 000000000..613c49822 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake_test.go @@ -0,0 +1,311 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + trC := newTransport(a, rand.Reader, true) + trS := newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + go func() { + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress + for i := 0; i < 10; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i == 5 { + // halfway through, we request a key change. + _, _, err := trC.sendKexInit() + if err != nil { + t.Fatalf("sendKexInit: %v", err) + } + } + } + trC.Close() + }() + + // Server checks that client messages come in cleanly + i := 0 + for { + p, err := trS.readPacket() + if err != nil { + break + } + if p[0] == msgNewKeys { + continue + } + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %q, want %q", i, p, want) + } + i++ + } + if i != 10 { + t.Errorf("received %d messages, want 10.", i) + } + + // If all went well, we registered exactly 1 key change. + if len(checker.calls) != 1 { + t.Fatalf("got %d host key checks, want 1", len(checker.calls)) + } + + pub := testSigners["ecdsa"].PublicKey() + want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) + if want != checker.calls[0] { + t.Errorf("got %q want %q for host key check", checker.calls[0], want) + } +} + +func TestHandshakeError(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + // send a packet + packet := []byte{msgRequestSuccess, 42} + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // the key change will fail, and afterwards we can't write. + if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { + t.Errorf("writePacket after botched rekey succeeded.") + } + + readback, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if bytes.Compare(readback, packet) != 0 { + t.Errorf("got %q want %q", readback, packet) + } + readback, err = trS.readPacket() + if err == nil { + t.Errorf("got a message %q after failed key change", readback) + } +} + +func TestHandshakeTwice(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // send a packet + packet := make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // Send another packet. Use a fresh one, since writePacket destroys. + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // 2nd key change. + _, _, err = trC.sendKexInit() + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + for i := 0; i < 5; i++ { + msg, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if msg[0] == msgNewKeys { + continue + } + + if bytes.Compare(msg, packet) != 0 { + t.Errorf("packet %d: got %q want %q", i, msg, packet) + } + } + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, want 2", len(checker.calls)) + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + for i := 0; i < 5; i++ { + packet := make([]byte, 251) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + } + + j := 0 + for ; j < 5; j++ { + _, err := trS.readPacket() + if err != nil { + break + } + } + + if j != 5 { + t.Errorf("got %d, want 5 messages", j) + } + + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, wanted 2", len(checker.calls)) + } +} + +type syncChecker struct { + called chan int +} + +func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + t.called <- 1 + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{make(chan int, 2)} + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + // While we read out the packet, a key change will be + // initiated. + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex.go new file mode 100644 index 000000000..6a835c763 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex.go @@ -0,0 +1,386 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "io" + "math/big" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to signal data inside transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p *big.Int +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + hashFunc := crypto.SHA1 + + x, err := rand.Int(randSource, group.p) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + kInt, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: crypto.SHA1, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + hashFunc := crypto.SHA1 + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + y, err := rand.Int(randSource, group.p) + if err != nil { + return + } + + Y := new(big.Int).Exp(group.g, y, group.p) + kInt, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA1, + }, nil +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in RFC + // 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + } + + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex_test.go new file mode 100644 index 000000000..0db5f9be1 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/kex_test.go @@ -0,0 +1,48 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Key exchange tests. + +import ( + "crypto/rand" + "reflect" + "testing" +) + +func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys.go new file mode 100644 index 000000000..e8af511ee --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys.go @@ -0,0 +1,628 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" +) + +// These constants represent the algorithm names for key types supported by this +// package. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: + cert, err := parseCert(in, certToPrivAlgo(algo)) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseAuthorizedKeys parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + // Type returns the key's type, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, + // with the name prefix. + Marshal() []byte + + // Verify that sig is a signature on the given data using this + // key. This function will hash the data appropriately first. + Verify(data []byte, sig *Signature) error +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + // PublicKey returns an associated PublicKey instance. + PublicKey() PublicKey + + // Sign returns raw signature for the given data. This method + // will apply the hash specified for the keytype to the data. + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != r.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) +} + +type rsaPrivateKey struct { + *rsa.PrivateKey +} + +func (r *rsaPrivateKey) PublicKey() PublicKey { + return (*rsaPublicKey)(&r.PrivateKey.PublicKey) +} + +func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) + if err != nil { + return nil, err + } + return &Signature{ + Format: r.PublicKey().Type(), + Blob: blob, + }, nil +} + +type dsaPublicKey dsa.PublicKey + +func (r *dsaPublicKey) Type() string { + return "ssh-dss" +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (key *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + key.nistID() +} + +func (key *ecdsaPublicKey) nistID() string { + switch key.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + identifier, in, ok := parseString(in) + if !ok { + return nil, nil, errShortRead + } + + key := new(ecdsa.PublicKey) + + switch string(identifier) { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + var keyBytes []byte + if keyBytes, in, ok = parseString(in); !ok { + return nil, nil, errShortRead + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), in, nil +} + +func (key *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) + w := struct { + Name string + ID string + Key []byte + }{ + key.Type(), + key.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + + h := ecHash(key.Curve).New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +type ecdsaPrivateKey struct { + *ecdsa.PrivateKey +} + +func (k *ecdsaPrivateKey) PublicKey() PublicKey { + return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + h := ecHash(k.PrivateKey.PublicKey.Curve).New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, intLength(r)+intLength(s)) + rest := marshalInt(sig, r) + marshalInt(rest, s) + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey +// returns a corresponding Signer instance. EC keys should use P256, +// P384 or P521. +func NewSignerFromKey(k interface{}) (Signer, error) { + var sshKey Signer + switch t := k.(type) { + case *rsa.PrivateKey: + sshKey = &rsaPrivateKey{t} + case *dsa.PrivateKey: + sshKey = &dsaPrivateKey{t} + case *ecdsa.PrivateKey: + if !supportedEllipticCurve(t.Curve) { + return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") + } + + sshKey = &ecdsaPrivateKey{t} + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", k) + } + return sshKey, nil +} + +// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey +// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521. +func NewPublicKey(k interface{}) (PublicKey, error) { + var sshKey PublicKey + switch t := k.(type) { + case *rsa.PublicKey: + sshKey = (*rsaPublicKey)(t) + case *ecdsa.PublicKey: + if !supportedEllipticCurve(t.Curve) { + return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") + } + sshKey = (*ecdsaPublicKey)(t) + case *dsa.PublicKey: + sshKey = (*dsaPublicKey)(t) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", k) + } + return sshKey, nil +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It +// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Priv *big.Int + Pub *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Priv, + }, + X: k.Pub, + }, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys_test.go new file mode 100644 index 000000000..36b97ad22 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/keys_test.go @@ -0,0 +1,306 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "reflect" + "strings" + "testing" + + "golang.org/x/crypto/ssh/testdata" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") { + t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") { + t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +type authResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { + rest := authKeys + var values []authResult + for len(rest) > 0 { + var r authResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []authResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []authResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/mac.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mac.go new file mode 100644 index 000000000..aff404291 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mac.go @@ -0,0 +1,53 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "hash" +) + +type macMode struct { + keySize int + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha1": {20, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/mempipe_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mempipe_test.go new file mode 100644 index 000000000..92519dd6b --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mempipe_test.go @@ -0,0 +1,110 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + return nil +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestmemPipe(t *testing.T) { + a, b := memPipe() + if err := a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages.go new file mode 100644 index 000000000..f9e44bb1e --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages.go @@ -0,0 +1,724 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 + + // Standard authentication messages + msgUserAuthSuccess = 52 + msgUserAuthBanner = 53 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Helman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + User string `sshtype:"60"` + Instruction string + DeprecatedLanguage string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersId uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersId uint32 `sshtype:"91"` + MyId uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersId uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersId uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersId uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersId uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersId uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersId uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersId uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// typeTag returns the type byte for the given type. The type should +// be struct. +func typeTag(structType reflect.Type) byte { + var tag byte + var tagStr string + tagStr = structType.Field(0).Tag.Get("sshtype") + i, err := strconv.Atoi(tagStr) + if err == nil { + tag = byte(i) + } + return tag +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a number in decimal, the packet +// must start that number. In case of error, Unmarshal returns a +// ParseError or UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedType := typeTag(structType) + if len(data) == 0 { + return parseError(expectedType) + } + if expectedType > 0 { + if data[0] != expectedType { + return unexpectedMessageError(expectedType, data[0]) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, "unsupported type") + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgType := typeTag(v.Type()) + if msgType > 0 { + out = append(out, msgType) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + if uint32(len(in)) < 4+length { + return + } + out = in[4 : 4+length] + rest = in[4+length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages_test.go new file mode 100644 index 000000000..21d52daf2 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/messages_test.go @@ -0,0 +1,244 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + +func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux.go new file mode 100644 index 000000000..321880ad9 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux.go @@ -0,0 +1,356 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, 16), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, 16), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +// TODO(hanwen): Disconnect is a transport layer message. We should +// probably send and receive Disconnect somewhere in the transport +// code. + +// Disconnect sends a disconnect message. +func (m *mux) Disconnect(reason uint32, message string) error { + return m.sendMessage(disconnectMsg{ + Reason: reason, + Message: message, + }) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgNewKeys: + // Ignore notification of key change. + return nil + case msgDisconnect: + return m.handleDisconnect(packet) + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return fmt.Errorf("ssh: invalid channel %d", id) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleDisconnect(packet []byte) error { + var d disconnectMsg + if err := Unmarshal(packet, &d); err != nil { + return err + } + + if debugMux { + log.Printf("caught disconnect: %v", d) + } + return &d +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersId: msg.PeersId, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersId: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux_test.go new file mode 100644 index 000000000..523038960 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux_test.go @@ -0,0 +1,525 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "io/ioutil" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Fatalf("No incoming channel") + } + if newCh.ChannelType() != "chan" { + t.Fatalf("got type %q want chan", newCh.ChannelType()) + } + ch, _, err := newCh.Accept() + if err != nil { + t.Fatalf("Accept %v", err) + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + + return <-res, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := ioutil.ReadAll(reader) + if string(c) != magic { + t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := ioutil.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + go func() { + _, err := s.Write([]byte(magic)) + if err != nil { + t.Fatalf("Write: %v", err) + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Fatalf("Write: %v", err) + } + err = s.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + wDone <- 1 + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } + <-wDone +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() + <-wDone +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() + <-wDone +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + go func() { + ch, ok := <-server.incomingChannels + if !ok { + t.Fatalf("Accept") + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxGlobalRequest(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + var seen bool + go func() { + for r := range serverMux.incomingRequests { + seen = seen || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } + + clientMux.Disconnect(0, "") + if !seen { + t.Errorf("never saw 'peek' request") + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxDisconnect(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + go func() { + for r := range b.incomingRequests { + r.Reply(true, nil) + } + }() + + a.Disconnect(42, "whatever") + ok, _, err := a.SendRequest("hello", true, nil) + if ok || err == nil { + t.Errorf("got reply after disconnecting") + } + err = b.Wait() + if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { + t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := ioutil.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + go a.SendRequest("hello", false, nil) + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/server.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/server.go new file mode 100644 index 000000000..52aee11e3 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/server.go @@ -0,0 +1,489 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a +// user. Permissions, except for "source-address", must be enforced in +// the server application layer, after successful authentication. The +// Permissions are passed on in ServerConn so a server implementation +// can honor them. +type Permissions struct { + // Critical options restrict default permissions. Common + // restrictions are "source-address" and "force-command". If + // the server cannot enforce the restriction, or does not + // recognize it, the user should not authenticate. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Common extensions are + // "permit-agent-forwarding", "permit-X11-forwarding". Lack of + // support for an extension does not preclude authenticating a + // user. + Extensions map[string]string +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + hostKeys []Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client attempts public + // key authentication. It must return true if the given public key is + // valid for the given user. For example, see CertChecker.Authenticate. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // ServerVersion is the version identification string to + // announce in the public handshake. + // If empty, a reasonable default is used. + ServerVersion string +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same algorithm, it is overwritten. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + for i, k := range s.hostKeys { + if k.PublicKey().Type() == key.PublicKey().Type() { + s.hostKeys[i] = key + return + } + } + + s.hostKeys = append(s.hostKeys, key) +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +const maxCachedPubKeys = 16 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) < maxCachedPubKeys { + c.keys = append(c.keys, candidate) + } +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. +func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { + sig, err := k.Sign(rand, data) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.requestKeyChange(); err != nil { + return nil, err + } + + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if packet[0] != msgNewKeys { + return nil, unexpectedMessageError(msgNewKeys, packet[0]) + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func isAcceptableAlgo(algo string) bool { + switch algo { + case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: + return true + } + return false +} + +func checkSourceAddress(addr net.Addr, sourceAddr string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if bytes.Equal(allowedIP, tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + var err error + var cache pubKeyCache + var perms *Permissions + +userAuthLoop: + for { + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + s.user = userAuthReq.User + perms = nil + authErr := errors.New("no auth passed yet") + + switch userAuthReq.Method { + case "none": + if config.NoClientAuth { + s.user = "" + authErr = nil + } + case "password": + if config.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = config.PasswordCallback(s, password) + case "keyboard-interactive": + if config.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configubred") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if config.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !isAcceptableAlgo(algo) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) + if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + candidate.result = checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]) + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + if candidate.result == nil { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !isAcceptableAlgo(sig.Format) { + break + } + signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + if authErr == nil { + break userAuthLoop + } + + var failureMsg userAuthFailureMsg + if config.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if config.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if config.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/session.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/session.go new file mode 100644 index 000000000..3b42b508a --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/session.go @@ -0,0 +1,605 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of ioutil.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the command fails to run or doesn't complete successfully, the +// error is of type *ExitError. Other error types may be +// returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return fmt.Errorf("ssh: cound not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the command fails to run or doesn't complete successfully, the +// error is of type *ExitError. Other error types may be +// returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for _ = range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + d := msg.Payload + wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + return errors.New("wait: remote command exited without exit status or exit signal") + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + return &ExitError{wm} +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal) +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/session_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/session_test.go new file mode 100644 index 000000000..88e66bf48 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/session_test.go @@ -0,0 +1,680 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "io" + "io/ioutil" + "math/rand" + "testing" + + "golang.org/x/crypto/ssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + go func() { + defer c1.Close() + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(testSigners["rsa"]) + + _, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("Unable to handshake: %v", err) + } + go DiscardRequests(reqs) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue + } + + ch, inReqs, err := newCh.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + continue + } + go func() { + handler(ch, inReqs, t) + }() + } + }() + + config := &ClientConfig{ + User: "testuser", + } + + conn, chans, reqs, err := NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("unable to dial remote side: %v", err) + } + + return NewClient(conn, chans, reqs) +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// Test that a simple string is returned via the Output helper, +// and that stderr is discarded. +func TestSessionOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.Output("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + w := "this-is-stdout." + g := string(buf) + if g != w { + t.Error("Remote command did not return expected string:") + t.Logf("want %q", w) + t.Logf("got %q", g) + } +} + +// Test that both stdout and stderr are returned +// via the CombinedOutput helper. +func TestSessionCombinedOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + const stdout = "this-is-stdout." + const stderr = "this-is-stderr." + g := string(buf) + if g != stdout+stderr && g != stderr+stdout { + t.Error("Remote command did not return expected string:") + t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) + t.Logf("got %q", g) + } +} + +// Test non-0 exit status is returned correctly. +func TestExitStatusNonZero(t *testing.T) { + conn := dial(exitStatusNonZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) + } +} + +// Test 0 exit status is returned correctly. +func TestExitStatusZero(t *testing.T) { + conn := dial(exitStatusZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +// Test exit signal and status are both returned correctly. +func TestExitSignalAndStatus(t *testing.T) { + conn := dial(exitSignalAndStatusHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestKnownExitSignalOnly(t *testing.T) { + conn := dial(exitSignalHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 143 { + t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestUnknownExitSignal(t *testing.T) { + conn := dial(exitSignalUnknownHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "SYS" || e.ExitStatus() != 128 { + t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test WaitMsg is not returned if the channel closes abruptly. +func TestExitWithoutStatusOrSignal(t *testing.T) { + conn := dial(exitWithoutSignalOrStatus, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + _, ok := err.(*ExitError) + if ok { + // you can't actually test for errors.errorString + // because it's not exported. + t.Fatalf("expected *errorString but got %T", err) + } +} + +// windowTestBytes is the number of bytes that we'll send to the SSH server. +const windowTestBytes = 16000 * 200 + +// TestServerWindow writes random data to the server. The server is expected to echo +// the same data back, which is compared against the original. +func TestServerWindow(t *testing.T) { + origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) + origBytes := origBuf.Bytes() + + conn := dial(echoHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + result := make(chan []byte) + + go func() { + defer close(result) + echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + serverStdout, err := session.StdoutPipe() + if err != nil { + t.Errorf("StdoutPipe failed: %v", err) + return + } + n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) + if err != nil && err != io.EOF { + t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) + } + result <- echoedBuf.Bytes() + }() + + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) + if err != nil { + t.Fatalf("failed to copy origBuf to serverStdin: %v", err) + } + if written != windowTestBytes { + t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) + } + + echoedBytes := <-result + + if !bytes.Equal(origBytes, echoedBytes) { + t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) + } +} + +// Verify the client can handle a keepalive packet from the server. +func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(ioutil.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := ioutil.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := ioutil.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + User: "user", + } + + go func() { + conn, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Fatalf("server handshake: %v", err) + } + serverID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + go func() { + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + clientID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip.go new file mode 100644 index 000000000..4ecad0b3f --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip.go @@ -0,0 +1,404 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +func (c *Client) Listen(n, addr string) (net.Listener, error) { + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(*laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.TCPAddr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr *net.TCPAddr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.TCPAddr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + addr, + make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var payload forwardedTCPPayload + if err := Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err := parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + if ok := l.forward(*laddr, *raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.TCPAddr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { + f.c <- forward{ch, &raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &tcpChanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(*l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// tcpChanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type tcpChanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *tcpChanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *tcpChanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *tcpChanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip_test.go new file mode 100644 index 000000000..f1265cb49 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/tcpip_test.go @@ -0,0 +1,20 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "testing" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal.go new file mode 100644 index 000000000..741eeb13f --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal.go @@ -0,0 +1,892 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "sync" + "unicode/utf8" +) + +// EscapeCodes contains escape sequences that can be written to the terminal in +// order to achieve different styles of text. +type EscapeCodes struct { + // Foreground colors + Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte + + // Reset all attributes + Reset []byte +} + +var vt100EscapeCodes = EscapeCodes{ + Black: []byte{keyEscape, '[', '3', '0', 'm'}, + Red: []byte{keyEscape, '[', '3', '1', 'm'}, + Green: []byte{keyEscape, '[', '3', '2', 'm'}, + Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, + Blue: []byte{keyEscape, '[', '3', '4', 'm'}, + Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, + Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, + White: []byte{keyEscape, '[', '3', '7', 'm'}, + + Reset: []byte{keyEscape, '[', '0', 'm'}, +} + +// Terminal contains the state for running a VT100 terminal that is capable of +// reading lines of input. +type Terminal struct { + // AutoCompleteCallback, if non-null, is called for each keypress with + // the full input line and the current position of the cursor (in + // bytes, as an index into |line|). If it returns ok=false, the key + // press is processed normally. Otherwise it returns a replacement line + // and the new cursor position. + AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) + + // Escape contains a pointer to the escape codes for this terminal. + // It's always a valid pointer, although the escape codes themselves + // may be empty if the terminal doesn't support them. + Escape *EscapeCodes + + // lock protects the terminal and the state in this object from + // concurrent processing of a key press and a Write() call. + lock sync.Mutex + + c io.ReadWriter + prompt []rune + + // line is the current line being entered. + line []rune + // pos is the logical position of the cursor in line + pos int + // echo is true if local echo is enabled + echo bool + // pasteActive is true iff there is a bracketed paste operation in + // progress. + pasteActive bool + + // cursorX contains the current X value of the cursor where the left + // edge is 0. cursorY contains the row number where the first row of + // the current line is 0. + cursorX, cursorY int + // maxLine is the greatest value of cursorY so far. + maxLine int + + termWidth, termHeight int + + // outBuf contains the terminal data to be sent. + outBuf []byte + // remainder contains the remainder of any partial key sequences after + // a read. It aliases into inBuf. + remainder []byte + inBuf [256]byte + + // history contains previously entered commands so that they can be + // accessed with the up and down keys. + history stRingBuffer + // historyIndex stores the currently accessed history entry, where zero + // means the immediately previous entry. + historyIndex int + // When navigating up and down the history it's possible to return to + // the incomplete, initial line. That value is stored in + // historyPending. + historyPending string +} + +// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is +// a local terminal, that terminal must first have been put into raw mode. +// prompt is a string that is written at the start of each input line (i.e. +// "> "). +func NewTerminal(c io.ReadWriter, prompt string) *Terminal { + return &Terminal{ + Escape: &vt100EscapeCodes, + c: c, + prompt: []rune(prompt), + termWidth: 80, + termHeight: 24, + echo: true, + historyIndex: -1, + } +} + +const ( + keyCtrlD = 4 + keyCtrlU = 21 + keyEnter = '\r' + keyEscape = 27 + keyBackspace = 127 + keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota + keyUp + keyDown + keyLeft + keyRight + keyAltLeft + keyAltRight + keyHome + keyEnd + keyDeleteWord + keyDeleteLine + keyClearScreen + keyPasteStart + keyPasteEnd +) + +var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} +var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} + +// bytesToKey tries to parse a key sequence from b. If successful, it returns +// the key and the remainder of the input. Otherwise it returns utf8.RuneError. +func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { + if len(b) == 0 { + return utf8.RuneError, nil + } + + if !pasteActive { + switch b[0] { + case 1: // ^A + return keyHome, b[1:] + case 5: // ^E + return keyEnd, b[1:] + case 8: // ^H + return keyBackspace, b[1:] + case 11: // ^K + return keyDeleteLine, b[1:] + case 12: // ^L + return keyClearScreen, b[1:] + case 23: // ^W + return keyDeleteWord, b[1:] + } + } + + if b[0] != keyEscape { + if !utf8.FullRune(b) { + return utf8.RuneError, b + } + r, l := utf8.DecodeRune(b) + return r, b[l:] + } + + if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { + switch b[2] { + case 'A': + return keyUp, b[3:] + case 'B': + return keyDown, b[3:] + case 'C': + return keyRight, b[3:] + case 'D': + return keyLeft, b[3:] + case 'H': + return keyHome, b[3:] + case 'F': + return keyEnd, b[3:] + } + } + + if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { + switch b[5] { + case 'C': + return keyAltRight, b[6:] + case 'D': + return keyAltLeft, b[6:] + } + } + + if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { + return keyPasteStart, b[6:] + } + + if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { + return keyPasteEnd, b[6:] + } + + // If we get here then we have a key that we don't recognise, or a + // partial sequence. It's not clear how one should find the end of a + // sequence without knowing them all, but it seems that [a-zA-Z~] only + // appears at the end of a sequence. + for i, c := range b[0:] { + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { + return keyUnknown, b[i+1:] + } + } + + return utf8.RuneError, b +} + +// queue appends data to the end of t.outBuf +func (t *Terminal) queue(data []rune) { + t.outBuf = append(t.outBuf, []byte(string(data))...) +} + +var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} +var space = []rune{' '} + +func isPrintable(key rune) bool { + isInSurrogateArea := key >= 0xd800 && key <= 0xdbff + return key >= 32 && !isInSurrogateArea +} + +// moveCursorToPos appends data to t.outBuf which will move the cursor to the +// given, logical position in the text. +func (t *Terminal) moveCursorToPos(pos int) { + if !t.echo { + return + } + + x := visualLength(t.prompt) + pos + y := x / t.termWidth + x = x % t.termWidth + + up := 0 + if y < t.cursorY { + up = t.cursorY - y + } + + down := 0 + if y > t.cursorY { + down = y - t.cursorY + } + + left := 0 + if x < t.cursorX { + left = t.cursorX - x + } + + right := 0 + if x > t.cursorX { + right = x - t.cursorX + } + + t.cursorX = x + t.cursorY = y + t.move(up, down, left, right) +} + +func (t *Terminal) move(up, down, left, right int) { + movement := make([]rune, 3*(up+down+left+right)) + m := movement + for i := 0; i < up; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'A' + m = m[3:] + } + for i := 0; i < down; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'B' + m = m[3:] + } + for i := 0; i < left; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'D' + m = m[3:] + } + for i := 0; i < right; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'C' + m = m[3:] + } + + t.queue(movement) +} + +func (t *Terminal) clearLineToRight() { + op := []rune{keyEscape, '[', 'K'} + t.queue(op) +} + +const maxLineLength = 4096 + +func (t *Terminal) setLine(newLine []rune, newPos int) { + if t.echo { + t.moveCursorToPos(0) + t.writeLine(newLine) + for i := len(newLine); i < len(t.line); i++ { + t.writeLine(space) + } + t.moveCursorToPos(newPos) + } + t.line = newLine + t.pos = newPos +} + +func (t *Terminal) advanceCursor(places int) { + t.cursorX += places + t.cursorY += t.cursorX / t.termWidth + if t.cursorY > t.maxLine { + t.maxLine = t.cursorY + } + t.cursorX = t.cursorX % t.termWidth + + if places > 0 && t.cursorX == 0 { + // Normally terminals will advance the current position + // when writing a character. But that doesn't happen + // for the last character in a line. However, when + // writing a character (except a new line) that causes + // a line wrap, the position will be advanced two + // places. + // + // So, if we are stopping at the end of a line, we + // need to write a newline so that our cursor can be + // advanced to the next line. + t.outBuf = append(t.outBuf, '\n') + } +} + +func (t *Terminal) eraseNPreviousChars(n int) { + if n == 0 { + return + } + + if t.pos < n { + n = t.pos + } + t.pos -= n + t.moveCursorToPos(t.pos) + + copy(t.line[t.pos:], t.line[n+t.pos:]) + t.line = t.line[:len(t.line)-n] + if t.echo { + t.writeLine(t.line[t.pos:]) + for i := 0; i < n; i++ { + t.queue(space) + } + t.advanceCursor(n) + t.moveCursorToPos(t.pos) + } +} + +// countToLeftWord returns then number of characters from the cursor to the +// start of the previous word. +func (t *Terminal) countToLeftWord() int { + if t.pos == 0 { + return 0 + } + + pos := t.pos - 1 + for pos > 0 { + if t.line[pos] != ' ' { + break + } + pos-- + } + for pos > 0 { + if t.line[pos] == ' ' { + pos++ + break + } + pos-- + } + + return t.pos - pos +} + +// countToRightWord returns then number of characters from the cursor to the +// start of the next word. +func (t *Terminal) countToRightWord() int { + pos := t.pos + for pos < len(t.line) { + if t.line[pos] == ' ' { + break + } + pos++ + } + for pos < len(t.line) { + if t.line[pos] != ' ' { + break + } + pos++ + } + return pos - t.pos +} + +// visualLength returns the number of visible glyphs in s. +func visualLength(runes []rune) int { + inEscapeSeq := false + length := 0 + + for _, r := range runes { + switch { + case inEscapeSeq: + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEscapeSeq = false + } + case r == '\x1b': + inEscapeSeq = true + default: + length++ + } + } + + return length +} + +// handleKey processes the given key and, optionally, returns a line of text +// that the user has entered. +func (t *Terminal) handleKey(key rune) (line string, ok bool) { + if t.pasteActive && key != keyEnter { + t.addKeyToLine(key) + return + } + + switch key { + case keyBackspace: + if t.pos == 0 { + return + } + t.eraseNPreviousChars(1) + case keyAltLeft: + // move left by a word. + t.pos -= t.countToLeftWord() + t.moveCursorToPos(t.pos) + case keyAltRight: + // move right by a word. + t.pos += t.countToRightWord() + t.moveCursorToPos(t.pos) + case keyLeft: + if t.pos == 0 { + return + } + t.pos-- + t.moveCursorToPos(t.pos) + case keyRight: + if t.pos == len(t.line) { + return + } + t.pos++ + t.moveCursorToPos(t.pos) + case keyHome: + if t.pos == 0 { + return + } + t.pos = 0 + t.moveCursorToPos(t.pos) + case keyEnd: + if t.pos == len(t.line) { + return + } + t.pos = len(t.line) + t.moveCursorToPos(t.pos) + case keyUp: + entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) + if !ok { + return "", false + } + if t.historyIndex == -1 { + t.historyPending = string(t.line) + } + t.historyIndex++ + runes := []rune(entry) + t.setLine(runes, len(runes)) + case keyDown: + switch t.historyIndex { + case -1: + return + case 0: + runes := []rune(t.historyPending) + t.setLine(runes, len(runes)) + t.historyIndex-- + default: + entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) + if ok { + t.historyIndex-- + runes := []rune(entry) + t.setLine(runes, len(runes)) + } + } + case keyEnter: + t.moveCursorToPos(len(t.line)) + t.queue([]rune("\r\n")) + line = string(t.line) + ok = true + t.line = t.line[:0] + t.pos = 0 + t.cursorX = 0 + t.cursorY = 0 + t.maxLine = 0 + case keyDeleteWord: + // Delete zero or more spaces and then one or more characters. + t.eraseNPreviousChars(t.countToLeftWord()) + case keyDeleteLine: + // Delete everything from the current cursor position to the + // end of line. + for i := t.pos; i < len(t.line); i++ { + t.queue(space) + t.advanceCursor(1) + } + t.line = t.line[:t.pos] + t.moveCursorToPos(t.pos) + case keyCtrlD: + // Erase the character under the current position. + // The EOF case when the line is empty is handled in + // readLine(). + if t.pos < len(t.line) { + t.pos++ + t.eraseNPreviousChars(1) + } + case keyCtrlU: + t.eraseNPreviousChars(t.pos) + case keyClearScreen: + // Erases the screen and moves the cursor to the home position. + t.queue([]rune("\x1b[2J\x1b[H")) + t.queue(t.prompt) + t.cursorX, t.cursorY = 0, 0 + t.advanceCursor(visualLength(t.prompt)) + t.setLine(t.line, t.pos) + default: + if t.AutoCompleteCallback != nil { + prefix := string(t.line[:t.pos]) + suffix := string(t.line[t.pos:]) + + t.lock.Unlock() + newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) + t.lock.Lock() + + if completeOk { + t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) + return + } + } + if !isPrintable(key) { + return + } + if len(t.line) == maxLineLength { + return + } + t.addKeyToLine(key) + } + return +} + +// addKeyToLine inserts the given key at the current position in the current +// line. +func (t *Terminal) addKeyToLine(key rune) { + if len(t.line) == cap(t.line) { + newLine := make([]rune, len(t.line), 2*(1+len(t.line))) + copy(newLine, t.line) + t.line = newLine + } + t.line = t.line[:len(t.line)+1] + copy(t.line[t.pos+1:], t.line[t.pos:]) + t.line[t.pos] = key + if t.echo { + t.writeLine(t.line[t.pos:]) + } + t.pos++ + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) writeLine(line []rune) { + for len(line) != 0 { + remainingOnLine := t.termWidth - t.cursorX + todo := len(line) + if todo > remainingOnLine { + todo = remainingOnLine + } + t.queue(line[:todo]) + t.advanceCursor(visualLength(line[:todo])) + line = line[todo:] + } +} + +func (t *Terminal) Write(buf []byte) (n int, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.cursorX == 0 && t.cursorY == 0 { + // This is the easy case: there's nothing on the screen that we + // have to move out of the way. + return t.c.Write(buf) + } + + // We have a prompt and possibly user input on the screen. We + // have to clear it first. + t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) + t.cursorX = 0 + t.clearLineToRight() + + for t.cursorY > 0 { + t.move(1 /* up */, 0, 0, 0) + t.cursorY-- + t.clearLineToRight() + } + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + + if n, err = t.c.Write(buf); err != nil { + return + } + + t.writeLine(t.prompt) + if t.echo { + t.writeLine(t.line) + } + + t.moveCursorToPos(t.pos) + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + return +} + +// ReadPassword temporarily changes the prompt and reads a password, without +// echo, from the terminal. +func (t *Terminal) ReadPassword(prompt string) (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + oldPrompt := t.prompt + t.prompt = []rune(prompt) + t.echo = false + + line, err = t.readLine() + + t.prompt = oldPrompt + t.echo = true + + return +} + +// ReadLine returns a line of input from the terminal. +func (t *Terminal) ReadLine() (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + return t.readLine() +} + +func (t *Terminal) readLine() (line string, err error) { + // t.lock must be held at this point + + if t.cursorX == 0 && t.cursorY == 0 { + t.writeLine(t.prompt) + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + } + + lineIsPasted := t.pasteActive + + for { + rest := t.remainder + lineOk := false + for !lineOk { + var key rune + key, rest = bytesToKey(rest, t.pasteActive) + if key == utf8.RuneError { + break + } + if !t.pasteActive { + if key == keyCtrlD { + if len(t.line) == 0 { + return "", io.EOF + } + } + if key == keyPasteStart { + t.pasteActive = true + if len(t.line) == 0 { + lineIsPasted = true + } + continue + } + } else if key == keyPasteEnd { + t.pasteActive = false + continue + } + if !t.pasteActive { + lineIsPasted = false + } + line, lineOk = t.handleKey(key) + } + if len(rest) > 0 { + n := copy(t.inBuf[:], rest) + t.remainder = t.inBuf[:n] + } else { + t.remainder = nil + } + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + if lineOk { + if t.echo { + t.historyIndex = -1 + t.history.Add(line) + } + if lineIsPasted { + err = ErrPasteIndicator + } + return + } + + // t.remainder is a slice at the beginning of t.inBuf + // containing a partial key sequence + readBuf := t.inBuf[len(t.remainder):] + var n int + + t.lock.Unlock() + n, err = t.c.Read(readBuf) + t.lock.Lock() + + if err != nil { + return + } + + t.remainder = t.inBuf[:n+len(t.remainder)] + } + + panic("unreachable") // for Go 1.0. +} + +// SetPrompt sets the prompt to be used when reading subsequent lines. +func (t *Terminal) SetPrompt(prompt string) { + t.lock.Lock() + defer t.lock.Unlock() + + t.prompt = []rune(prompt) +} + +func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { + // Move cursor to column zero at the start of the line. + t.move(t.cursorY, 0, t.cursorX, 0) + t.cursorX, t.cursorY = 0, 0 + t.clearLineToRight() + for t.cursorY < numPrevLines { + // Move down a line + t.move(0, 1, 0, 0) + t.cursorY++ + t.clearLineToRight() + } + // Move back to beginning. + t.move(t.cursorY, 0, 0, 0) + t.cursorX, t.cursorY = 0, 0 + + t.queue(t.prompt) + t.advanceCursor(visualLength(t.prompt)) + t.writeLine(t.line) + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) SetSize(width, height int) error { + t.lock.Lock() + defer t.lock.Unlock() + + if width == 0 { + width = 1 + } + + oldWidth := t.termWidth + t.termWidth, t.termHeight = width, height + + switch { + case width == oldWidth: + // If the width didn't change then nothing else needs to be + // done. + return nil + case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: + // If there is nothing on current line and no prompt printed, + // just do nothing + return nil + case width < oldWidth: + // Some terminals (e.g. xterm) will truncate lines that were + // too long when shinking. Others, (e.g. gnome-terminal) will + // attempt to wrap them. For the former, repainting t.maxLine + // works great, but that behaviour goes badly wrong in the case + // of the latter because they have doubled every full line. + + // We assume that we are working on a terminal that wraps lines + // and adjust the cursor position based on every previous line + // wrapping and turning into two. This causes the prompt on + // xterms to move upwards, which isn't great, but it avoids a + // huge mess with gnome-terminal. + if t.cursorX >= t.termWidth { + t.cursorX = t.termWidth - 1 + } + t.cursorY *= 2 + t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) + case width > oldWidth: + // If the terminal expands then our position calculations will + // be wrong in the future because we think the cursor is + // |t.pos| chars into the string, but there will be a gap at + // the end of any wrapped line. + // + // But the position will actually be correct until we move, so + // we can move back to the beginning and repaint everything. + t.clearAndRepaintLinePlusNPrevious(t.maxLine) + } + + _, err := t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + return err +} + +type pasteIndicatorError struct{} + +func (pasteIndicatorError) Error() string { + return "terminal: ErrPasteIndicator not correctly handled" +} + +// ErrPasteIndicator may be returned from ReadLine as the error, in addition +// to valid line data. It indicates that bracketed paste mode is enabled and +// that the returned line consists only of pasted data. Programs may wish to +// interpret pasted data more literally than typed data. +var ErrPasteIndicator = pasteIndicatorError{} + +// SetBracketedPasteMode requests that the terminal bracket paste operations +// with markers. Not all terminals support this but, if it is supported, then +// enabling this mode will stop any autocomplete callback from running due to +// pastes. Additionally, any lines that are completely pasted will be returned +// from ReadLine with the error set to ErrPasteIndicator. +func (t *Terminal) SetBracketedPasteMode(on bool) { + if on { + io.WriteString(t.c, "\x1b[?2004h") + } else { + io.WriteString(t.c, "\x1b[?2004l") + } +} + +// stRingBuffer is a ring buffer of strings. +type stRingBuffer struct { + // entries contains max elements. + entries []string + max int + // head contains the index of the element most recently added to the ring. + head int + // size contains the number of elements in the ring. + size int +} + +func (s *stRingBuffer) Add(a string) { + if s.entries == nil { + const defaultNumEntries = 100 + s.entries = make([]string, defaultNumEntries) + s.max = defaultNumEntries + } + + s.head = (s.head + 1) % s.max + s.entries[s.head] = a + if s.size < s.max { + s.size++ + } +} + +// NthPreviousEntry returns the value passed to the nth previous call to Add. +// If n is zero then the immediately prior value is returned, if one, then the +// next most recent, and so on. If such an element doesn't exist then ok is +// false. +func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { + if n >= s.size { + return "", false + } + index := s.head - n + if index < 0 { + index += s.max + } + return s.entries[index], true +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal_test.go new file mode 100644 index 000000000..a663fe41b --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/terminal_test.go @@ -0,0 +1,269 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "io" + "testing" +) + +type MockTerminal struct { + toSend []byte + bytesPerRead int + received []byte +} + +func (c *MockTerminal) Read(data []byte) (n int, err error) { + n = len(data) + if n == 0 { + return + } + if n > len(c.toSend) { + n = len(c.toSend) + } + if n == 0 { + return 0, io.EOF + } + if c.bytesPerRead > 0 && n > c.bytesPerRead { + n = c.bytesPerRead + } + copy(data, c.toSend[:n]) + c.toSend = c.toSend[n:] + return +} + +func (c *MockTerminal) Write(data []byte) (n int, err error) { + c.received = append(c.received, data...) + return len(data), nil +} + +func TestClose(t *testing.T) { + c := &MockTerminal{} + ss := NewTerminal(c, "> ") + line, err := ss.ReadLine() + if line != "" { + t.Errorf("Expected empty line but got: %s", line) + } + if err != io.EOF { + t.Errorf("Error should have been EOF but got: %s", err) + } +} + +var keyPressTests = []struct { + in string + line string + err error + throwAwayLines int +}{ + { + err: io.EOF, + }, + { + in: "\r", + line: "", + }, + { + in: "foo\r", + line: "foo", + }, + { + in: "a\x1b[Cb\r", // right + line: "ab", + }, + { + in: "a\x1b[Db\r", // left + line: "ba", + }, + { + in: "a\177b\r", // backspace + line: "b", + }, + { + in: "\x1b[A\r", // up + }, + { + in: "\x1b[B\r", // down + }, + { + in: "line\x1b[A\x1b[B\r", // up then down + line: "line", + }, + { + in: "line1\rline2\x1b[A\r", // recall previous line. + line: "line1", + throwAwayLines: 1, + }, + { + // recall two previous lines and append. + in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r", + line: "line1xxx", + throwAwayLines: 2, + }, + { + // Ctrl-A to move to beginning of line followed by ^K to kill + // line. + in: "a b \001\013\r", + line: "", + }, + { + // Ctrl-A to move to beginning of line, Ctrl-E to move to end, + // finally ^K to kill nothing. + in: "a b \001\005\013\r", + line: "a b ", + }, + { + in: "\027\r", + line: "", + }, + { + in: "a\027\r", + line: "", + }, + { + in: "a \027\r", + line: "", + }, + { + in: "a b\027\r", + line: "a ", + }, + { + in: "a b \027\r", + line: "a ", + }, + { + in: "one two thr\x1b[D\027\r", + line: "one two r", + }, + { + in: "\013\r", + line: "", + }, + { + in: "a\013\r", + line: "a", + }, + { + in: "ab\x1b[D\013\r", + line: "a", + }, + { + in: "Ξεσκεπάζω\r", + line: "Ξεσκεπάζω", + }, + { + in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace. + line: "", + throwAwayLines: 1, + }, + { + in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter. + line: "£", + throwAwayLines: 1, + }, + { + // Ctrl-D at the end of the line should be ignored. + in: "a\004\r", + line: "a", + }, + { + // a, b, left, Ctrl-D should erase the b. + in: "ab\x1b[D\004\r", + line: "a", + }, + { + // a, b, c, d, left, left, ^U should erase to the beginning of + // the line. + in: "abcd\x1b[D\x1b[D\025\r", + line: "cd", + }, + { + // Bracketed paste mode: control sequences should be returned + // verbatim in paste mode. + in: "abc\x1b[200~de\177f\x1b[201~\177\r", + line: "abcde\177", + }, + { + // Enter in bracketed paste mode should still work. + in: "abc\x1b[200~d\refg\x1b[201~h\r", + line: "efgh", + throwAwayLines: 1, + }, + { + // Lines consisting entirely of pasted data should be indicated as such. + in: "\x1b[200~a\r", + line: "a", + err: ErrPasteIndicator, + }, +} + +func TestKeyPresses(t *testing.T) { + for i, test := range keyPressTests { + for j := 1; j < len(test.in); j++ { + c := &MockTerminal{ + toSend: []byte(test.in), + bytesPerRead: j, + } + ss := NewTerminal(c, "> ") + for k := 0; k < test.throwAwayLines; k++ { + _, err := ss.ReadLine() + if err != nil { + t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err) + } + } + line, err := ss.ReadLine() + if line != test.line { + t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) + break + } + if err != test.err { + t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) + break + } + } + } +} + +func TestPasswordNotSaved(t *testing.T) { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + pw, _ := ss.ReadPassword("> ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + line, _ := ss.ReadLine() + if len(line) > 0 { + t.Fatalf("password was saved in history") + } +} + +var setSizeTests = []struct { + width, height int +}{ + {40, 13}, + {80, 24}, + {132, 43}, +} + +func TestTerminalSetSize(t *testing.T) { + for _, setSize := range setSizeTests { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + ss.SetSize(setSize.width, setSize.height) + pw, _ := ss.ReadPassword("Password: ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + if string(c.received) != "Password: \r\n" { + t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received) + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util.go new file mode 100644 index 000000000..598e3df77 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util.go @@ -0,0 +1,128 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal // import "golang.org/x/crypto/ssh/terminal" + +import ( + "io" + "syscall" + "unsafe" +) + +// State contains the state of a terminal. +type State struct { + termios syscall.Termios +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState.termios + newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF + newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var dimensions [4]uint16 + + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { + return -1, -1, err + } + return int(dimensions[1]), int(dimensions[0]), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var oldState syscall.Termios + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + defer func() { + syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_bsd.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_bsd.go new file mode 100644 index 000000000..9c1ffd145 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_bsd.go @@ -0,0 +1,12 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd netbsd openbsd + +package terminal + +import "syscall" + +const ioctlReadTermios = syscall.TIOCGETA +const ioctlWriteTermios = syscall.TIOCSETA diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_linux.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_linux.go new file mode 100644 index 000000000..5883b22d7 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_linux.go @@ -0,0 +1,11 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +// These constants are declared here, rather than importing +// them from the syscall package as some syscall packages, even +// on linux, for example gccgo, do not declare them. +const ioctlReadTermios = 0x5401 // syscall.TCGETS +const ioctlWriteTermios = 0x5402 // syscall.TCSETS diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_windows.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_windows.go new file mode 100644 index 000000000..2dd6c3d97 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/terminal/util_windows.go @@ -0,0 +1,174 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "io" + "syscall" + "unsafe" +) + +const ( + enableLineInput = 2 + enableEchoInput = 4 + enableProcessedInput = 1 + enableWindowInput = 8 + enableMouseInput = 16 + enableInsertMode = 32 + enableQuickEditMode = 64 + enableExtendedFlags = 128 + enableAutoPosition = 256 + enableProcessedOutput = 1 + enableWrapAtEolOutput = 2 +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") + +var ( + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +type ( + short int16 + word uint16 + + coord struct { + x short + y short + } + smallRect struct { + left short + top short + right short + bottom short + } + consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes word + window smallRect + maximumWindowSize coord + } +) + +type State struct { + mode uint32 +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var info consoleScreenBufferInfo + _, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) + if e != 0 { + return 0, 0, error(e) + } + return int(info.size.x), int(info.size.y), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + old := st + + st &^= (enableEchoInput) + st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + + defer func() { + syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(syscall.Handle(fd), buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + if n > 0 && buf[n-1] == '\r' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/agent_unix_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/agent_unix_test.go new file mode 100644 index 000000000..502e24feb --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/agent_unix_test.go @@ -0,0 +1,50 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "testing" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func TestAgentForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + keyring := agent.NewKeyring() + keyring.Add(testPrivateKeys["dsa"], nil, "") + pub := testPublicKeys["dsa"] + + sess, err := conn.NewSession() + if err != nil { + t.Fatalf("NewSession: %v", err) + } + if err := agent.RequestAgentForwarding(sess); err != nil { + t.Fatalf("RequestAgentForwarding: %v", err) + } + + if err := agent.ForwardToAgent(conn, keyring); err != nil { + t.Fatalf("SetupForwardKeyring: %v", err) + } + out, err := sess.CombinedOutput("ssh-add -L") + if err != nil { + t.Fatalf("running ssh-add: %v, out %s", err, out) + } + key, _, _, _, err := ssh.ParseAuthorizedKey(out) + if err != nil { + t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) + } + + if !bytes.Equal(key.Marshal(), pub.Marshal()) { + t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/cert_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/cert_test.go new file mode 100644 index 000000000..364790f17 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/cert_test.go @@ -0,0 +1,47 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "crypto/rand" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestCertLogin(t *testing.T) { + s := newServer(t) + defer s.Shutdown() + + // Use a key different from the default. + clientKey := testSigners["dsa"] + caAuthKey := testSigners["ecdsa"] + cert := &ssh.Certificate{ + Key: clientKey.PublicKey(), + ValidPrincipals: []string{username()}, + CertType: ssh.UserCert, + ValidBefore: ssh.CertTimeInfinity, + } + if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { + t.Fatalf("SetSignature: %v", err) + } + + certSigner, err := ssh.NewCertSigner(cert, clientKey) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + conf := &ssh.ClientConfig{ + User: username(), + } + conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) + client, err := s.TryDial(conf) + if err != nil { + t.Fatalf("TryDial: %v", err) + } + client.Close() +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/doc.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/doc.go new file mode 100644 index 000000000..3f9b3346d --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/doc.go @@ -0,0 +1,7 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains integration tests for the +// golang.org/x/crypto/ssh package. +package test // import "golang.org/x/crypto/ssh/test" diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/forward_unix_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/forward_unix_test.go new file mode 100644 index 000000000..877a88cde --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/forward_unix_test.go @@ -0,0 +1,160 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "net" + "testing" + "time" +) + +func TestPortForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + go func() { + sshConn, err := sshListener.Accept() + if err != nil { + t.Fatalf("listen.Accept failed: %v", err) + } + + _, err = io.Copy(sshConn, sshConn) + if err != nil && err != io.EOF { + t.Fatalf("ssh client copy: %v", err) + } + sshConn.Close() + }() + + forwardedAddr := sshListener.Addr().String() + tcpConn, err := net.Dial("tcp", forwardedAddr) + if err != nil { + t.Fatalf("TCP dial failed: %v", err) + } + + readChan := make(chan []byte) + go func() { + data, _ := ioutil.ReadAll(tcpConn) + readChan <- data + }() + + // Invent some data. + data := make([]byte, 100*1000) + for i := range data { + data[i] = byte(i % 255) + } + + var sent []byte + for len(sent) < 1000*1000 { + // Send random sized chunks + m := rand.Intn(len(data)) + n, err := tcpConn.Write(data[:m]) + if err != nil { + break + } + sent = append(sent, data[:n]...) + } + if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { + t.Errorf("tcpConn.CloseWrite: %v", err) + } + + read := <-readChan + + if len(sent) != len(read) { + t.Fatalf("got %d bytes, want %d", len(read), len(sent)) + } + if bytes.Compare(sent, read) != 0 { + t.Fatalf("read back data does not match") + } + + if err := sshListener.Close(); err != nil { + t.Fatalf("sshListener.Close: %v", err) + } + + // Check that the forward disappeared. + tcpConn, err = net.Dial("tcp", forwardedAddr) + if err == nil { + tcpConn.Close() + t.Errorf("still listening to %s after closing", forwardedAddr) + } +} + +func TestAcceptClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + sshListener.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} + +// Check that listeners exit if the underlying client transport dies. +func TestPortForwardConnectionClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + + // It would be even nicer if we closed the server side, but it + // is more involved as the fd for that side is dup()ed. + server.clientConn.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/session_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/session_test.go new file mode 100644 index 000000000..fbd1044f5 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/session_test.go @@ -0,0 +1,321 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// Session functional tests. + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestRunCommandSuccess(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func TestHostKeyCheck(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + + // change the keys. + hostDB.keys[ssh.KeyAlgoRSA][25]++ + hostDB.keys[ssh.KeyAlgoDSA][25]++ + hostDB.keys[ssh.KeyAlgoECDSA256][25]++ + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + t.Fatalf("dial should have failed.") + } else if !strings.Contains(err.Error(), "host key mismatch") { + t.Fatalf("'host key mismatch' not found in %v", err) + } +} + +func TestRunCommandStdin(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + defer w.Close() + session.Stdin = r + + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func TestRunCommandStdinError(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + session.Stdin = r + pipeErr := errors.New("closing write end of pipe") + w.CloseWithError(pipeErr) + + err = session.Run("true") + if err != pipeErr { + t.Fatalf("expected %v, found %v", pipeErr, err) + } +} + +func TestRunCommandFailed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run(`bash -c "kill -9 $$"`) + if err == nil { + t.Fatalf("session succeeded: %v", err) + } +} + +func TestRunCommandWeClosed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + err = session.Shell() + if err != nil { + t.Fatalf("shell failed: %v", err) + } + err = session.Close() + if err != nil { + t.Fatalf("shell failed: %v", err) + } +} + +func TestFuncLargeRead(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=2048 count=1024") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + if n != 2048*1024 { + t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) + } +} + +func TestKeyChange(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + conf.RekeyThreshold = 1024 + conn := server.Dial(conf) + defer conn.Close() + + for i := 0; i < 4; i++ { + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=1024 count=1") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + want := int64(1024) + if n != want { + t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) + } + } + + if changes := hostDB.checkCount; changes < 4 { + t.Errorf("got %d key changes, want 4", changes) + } +} + +func TestInvalidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil { + t.Fatalf("req-pty failed: successful request with invalid mode") + } +} + +func TestValidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("unable to acquire stdin pipe: %s", err) + } + + tm := ssh.TerminalModes{ssh.ECHO: 0} + if err = session.RequestPty("xterm", 80, 40, tm); err != nil { + t.Fatalf("req-pty failed: %s", err) + } + + err = session.Shell() + if err != nil { + t.Fatalf("session failed: %s", err) + } + + stdin.Write([]byte("stty -a && exit\n")) + + var buf bytes.Buffer + if _, err := io.Copy(&buf, stdout); err != nil { + t.Fatalf("reading failed: %s", err) + } + + if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") { + t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) + } +} + +func TestCiphers(t *testing.T) { + var config ssh.Config + config.SetDefaults() + cipherOrder := config.Ciphers + // This cipher will not be tested when commented out in cipher.go it will + // fallback to the next available as per line 292. + cipherOrder = append(cipherOrder, "aes128-cbc") + + for _, ciph := range cipherOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.Ciphers = []string{ciph} + // Don't fail if sshd doesnt have the cipher. + conf.Ciphers = append(conf.Ciphers, cipherOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Fatalf("failed for cipher %q", ciph) + } + } +} + +func TestMACs(t *testing.T) { + var config ssh.Config + config.SetDefaults() + macOrder := config.MACs + + for _, mac := range macOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.MACs = []string{mac} + // Don't fail if sshd doesnt have the MAC. + conf.MACs = append(conf.MACs, macOrder...) + if conn, err := server.TryDial(conf); err == nil { + conn.Close() + } else { + t.Fatalf("failed for MAC %q", mac) + } + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/tcpip_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/tcpip_test.go new file mode 100644 index 000000000..a2eb9358d --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/tcpip_test.go @@ -0,0 +1,46 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// direct-tcpip functional tests + +import ( + "io" + "net" + "testing" +) + +func TestDial(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + sshConn := server.Dial(clientConfig()) + defer sshConn.Close() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer l.Close() + + go func() { + for { + c, err := l.Accept() + if err != nil { + break + } + + io.WriteString(c, c.RemoteAddr().String()) + c.Close() + } + }() + + conn, err := sshConn.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/test_unix_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/test_unix_test.go new file mode 100644 index 000000000..f1fc50b2e --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/test_unix_test.go @@ -0,0 +1,261 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd plan9 + +package test + +// functional test harness for unix. + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "testing" + "text/template" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +const sshd_config = ` +Protocol 2 +HostKey {{.Dir}}/id_rsa +HostKey {{.Dir}}/id_dsa +HostKey {{.Dir}}/id_ecdsa +Pidfile {{.Dir}}/sshd.pid +#UsePrivilegeSeparation no +KeyRegenerationInterval 3600 +ServerKeyBits 768 +SyslogFacility AUTH +LogLevel DEBUG2 +LoginGraceTime 120 +PermitRootLogin no +StrictModes no +RSAAuthentication yes +PubkeyAuthentication yes +AuthorizedKeysFile {{.Dir}}/id_user.pub +TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub +IgnoreRhosts yes +RhostsRSAAuthentication no +HostbasedAuthentication no +` + +var configTmpl = template.Must(template.New("").Parse(sshd_config)) + +type server struct { + t *testing.T + cleanup func() // executed during Shutdown + configfile string + cmd *exec.Cmd + output bytes.Buffer // holds stderr from sshd process + + // Client half of the network connection. + clientConn net.Conn +} + +func username() string { + var username string + if user, err := user.Current(); err == nil { + username = user.Username + } else { + // user.Current() currently requires cgo. If an error is + // returned attempt to get the username from the environment. + log.Printf("user.Current: %v; falling back on $USER", err) + username = os.Getenv("USER") + } + if username == "" { + panic("Unable to get username") + } + return username +} + +type storedHostKey struct { + // keys map from an algorithm string to binary key data. + keys map[string][]byte + + // checkCount counts the Check calls. Used for testing + // rekeying. + checkCount int +} + +func (k *storedHostKey) Add(key ssh.PublicKey) { + if k.keys == nil { + k.keys = map[string][]byte{} + } + k.keys[key.Type()] = key.Marshal() +} + +func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { + k.checkCount++ + algo := key.Type() + + if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { + return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) + } + return nil +} + +func hostKeyDB() *storedHostKey { + keyChecker := &storedHostKey{} + keyChecker.Add(testPublicKeys["ecdsa"]) + keyChecker.Add(testPublicKeys["rsa"]) + keyChecker.Add(testPublicKeys["dsa"]) + return keyChecker +} + +func clientConfig() *ssh.ClientConfig { + config := &ssh.ClientConfig{ + User: username(), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(testSigners["user"]), + }, + HostKeyCallback: hostKeyDB().Check, + } + return config +} + +// unixConnection creates two halves of a connected net.UnixConn. It +// is used for connecting the Go SSH client with sshd without opening +// ports. +func unixConnection() (*net.UnixConn, *net.UnixConn, error) { + dir, err := ioutil.TempDir("", "unixConnection") + if err != nil { + return nil, nil, err + } + defer os.Remove(dir) + + addr := filepath.Join(dir, "ssh") + listener, err := net.Listen("unix", addr) + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("unix", addr) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1.(*net.UnixConn), c2.(*net.UnixConn), nil +} + +func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { + sshd, err := exec.LookPath("sshd") + if err != nil { + s.t.Skipf("skipping test: %v", err) + } + + c1, c2, err := unixConnection() + if err != nil { + s.t.Fatalf("unixConnection: %v", err) + } + + s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + f, err := c2.File() + if err != nil { + s.t.Fatalf("UnixConn.File: %v", err) + } + defer f.Close() + s.cmd.Stdin = f + s.cmd.Stdout = f + s.cmd.Stderr = &s.output + if err := s.cmd.Start(); err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("s.cmd.Start: %v", err) + } + s.clientConn = c1 + conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) + if err != nil { + return nil, err + } + return ssh.NewClient(conn, chans, reqs), nil +} + +func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { + conn, err := s.TryDial(config) + if err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("ssh.Client: %v", err) + } + return conn +} + +func (s *server) Shutdown() { + if s.cmd != nil && s.cmd.Process != nil { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + s.cmd.Process.Signal(os.Interrupt) + s.cmd.Wait() + } + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd: %s", s.output.String()) + } + s.cleanup() +} + +func writeFile(path string, contents []byte) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + defer f.Close() + if _, err := f.Write(contents); err != nil { + panic(err) + } +} + +// newServer returns a new mock ssh server. +func newServer(t *testing.T) *server { + if testing.Short() { + t.Skip("skipping test due to -short") + } + dir, err := ioutil.TempDir("", "sshtest") + if err != nil { + t.Fatal(err) + } + f, err := os.Create(filepath.Join(dir, "sshd_config")) + if err != nil { + t.Fatal(err) + } + err = configTmpl.Execute(f, map[string]string{ + "Dir": dir, + }) + if err != nil { + t.Fatal(err) + } + f.Close() + + for k, v := range testdata.PEMBytes { + filename := "id_" + k + writeFile(filepath.Join(dir, filename), v) + writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + + return &server{ + t: t, + configfile: f.Name(), + cleanup: func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }, + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/testdata_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/testdata_test.go new file mode 100644 index 000000000..ae48c7516 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/test/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package test + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/doc.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/doc.go new file mode 100644 index 000000000..fcae47ca6 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/doc.go @@ -0,0 +1,8 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains test data shared between the various subpackages of +// the golang.org/x/crypto/ssh package. Under no circumstance should +// this data be used for production code. +package testdata // import "golang.org/x/crypto/ssh/testdata" diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/keys.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/keys.go new file mode 100644 index 000000000..5ff1c0e03 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata/keys.go @@ -0,0 +1,43 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testdata + +var PEMBytes = map[string][]byte{ + "dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- +MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB +lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 +EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD +nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV +2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r +juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr +FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz +DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj +nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY +Fmsr0W6fHB9nhS4/UXM8 +-----END DSA PRIVATE KEY----- +`), + "ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 +AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ +6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== +-----END EC PRIVATE KEY----- +`), + "rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld +r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ +tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC +nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW +2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB +y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr +rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg== +-----END RSA PRIVATE KEY----- +`), + "user": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 +AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD +PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== +-----END EC PRIVATE KEY----- +`), +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata_test.go new file mode 100644 index 000000000..f2828c1b5 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport.go new file mode 100644 index 000000000..8351d378e --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport.go @@ -0,0 +1,332 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "errors" + "io" +) + +const ( + gcmCipherID = "aes128-gcm@openssh.com" + aes128cbcID = "aes128-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + + io.Closer + + // Initial H used for the session ID. Once assigned this does + // not change, even during subsequent key exchanges. + sessionID []byte +} + +// getSessionID returns the ID of the SSH connection. The return value +// should not be modified. +func (t *transport) getSessionID() []byte { + if t.sessionID == nil { + panic("session ID not set yet") + } + return t.sessionID +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writePacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + if t.sessionID == nil { + t.sessionID = kexResult.H + } + + kexResult.SessionID = t.sessionID + + if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { + return err + } else { + t.reader.pendingKeyChange <- ciph + } + + if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { + return err + } else { + t.writer.pendingKeyChange <- ciph + } + + return nil +} + +// Read and decrypt next packet. +func (t *transport) readPacket() ([]byte, error) { + return t.reader.readPacket(t.bufReader) +} + +func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { + packet, err := s.packetCipher.readPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 && packet[0] == msgNewKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + return nil, errors.New("ssh: got bogus newkeys message.") + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + return t.writer.writePacket(t.bufWriter, t.rand, packet) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// generateKeys generates key material for IV, MAC and encryption. +func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { + cipherMode := cipherModes[algs.Cipher] + macMode := macModes[algs.MAC] + + iv = make([]byte, cipherMode.ivSize) + key = make([]byte, cipherMode.keySize) + macKey = make([]byte, macMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + generateKeyMaterial(macKey, d.macKeyTag, kex) + return +} + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + iv, key, macKey := generateKeys(d, algs, kex) + + if algs.Cipher == gcmCipherID { + return newGCMCipher(iv, key, macKey) + } + + if algs.Cipher == aes128cbcID { + return newAESCBCCipher(iv, key, macKey, algs) + } + + c := &streamPacketCipher{ + mac: macModes[algs.MAC].new(macKey), + } + c.macResult = make([]byte, c.mac.Size()) + + var err error + c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) + if err != nil { + return nil, err + } + + return c, nil +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for len(versionString) < maxVersionStringBytes { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport_test.go b/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport_test.go new file mode 100644 index 000000000..92d83abf9 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/crypto/ssh/transport_test.go @@ -0,0 +1,109 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + longversion + "\r\n": longversion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := []string{ + longversion + "too-long\r\n", + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +}