2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-02 11:46:36 +00:00

Merge pull request #1108 from restic/update-deps

Update vendored deps
This commit is contained in:
Alexander Neumann 2017-07-17 22:01:52 +02:00
commit e7575bf380
270 changed files with 39513 additions and 25400 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path" "path"
"restic" "restic"
@ -197,53 +198,6 @@ func (be *Backend) Path() string {
return be.cfg.Prefix return be.cfg.Prefix
} }
// nopCloserFile wraps *os.File and overwrites the Close() method with method
// that does nothing. In addition, the method Len() is implemented, which
// returns the size of the file (filesize - current offset).
type nopCloserFile struct {
*os.File
}
func (f nopCloserFile) Close() error {
debug.Log("prevented Close()")
return nil
}
// Len returns the remaining length of the file (filesize - current offset).
func (f nopCloserFile) Len() int {
debug.Log("Len() called")
fi, err := f.Stat()
if err != nil {
panic(err)
}
pos, err := f.Seek(0, io.SeekCurrent)
if err != nil {
panic(err)
}
size := fi.Size() - pos
debug.Log("returning file size %v", size)
return int(size)
}
type lenner interface {
Len() int
io.Reader
}
// nopCloserLenner wraps a lenner and overwrites the Close() method with method
// that does nothing. In addition, the method Size() is implemented, which
// returns the size of the file (filesize - current offset).
type nopCloserLenner struct {
lenner
}
func (f *nopCloserLenner) Close() error {
debug.Log("prevented Close()")
return nil
}
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
debug.Log("Save %v", h) debug.Log("Save %v", h)
@ -262,15 +216,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err
} }
// prevent the HTTP client from closing a file // prevent the HTTP client from closing a file
if f, ok := rd.(*os.File); ok { rd = ioutil.NopCloser(rd)
debug.Log("reader is %#T, using nopCloserFile{}", rd)
rd = nopCloserFile{f}
} else if l, ok := rd.(lenner); ok {
debug.Log("reader is %#T, using nopCloserLenner{}", rd)
rd = nopCloserLenner{l}
} else {
debug.Log("reader is %#T, no specific workaround enabled", rd)
}
be.sem.GetToken() be.sem.GetToken()
debug.Log("PutObject(%v, %v)", be.cfg.Bucket, objName) debug.Log("PutObject(%v, %v)", be.cfg.Bucket, objName)
@ -479,8 +425,14 @@ func (be *Backend) Rename(h restic.Handle, l backend.Layout) error {
debug.Log(" %v -> %v", oldname, newname) debug.Log(" %v -> %v", oldname, newname)
coreClient := minio.Core{Client: be.client} src := minio.NewSourceInfo(be.cfg.Bucket, oldname, nil)
err := coreClient.CopyObject(be.cfg.Bucket, newname, path.Join(be.cfg.Bucket, oldname), minio.CopyConditions{})
dst, err := minio.NewDestinationInfo(be.cfg.Bucket, newname, nil, nil)
if err != nil {
return errors.Wrap(err, "NewDestinationInfo")
}
err = be.client.CopyObject(dst, src)
if err != nil && be.IsNotExist(err) { if err != nil && be.IsNotExist(err) {
debug.Log("copy failed: %v, seems to already have been renamed", err) debug.Log("copy failed: %v, seems to already have been renamed", err)
return nil return nil

50
vendor/manifest vendored
View File

@ -10,13 +10,13 @@
{ {
"importpath": "github.com/elithrar/simple-scrypt", "importpath": "github.com/elithrar/simple-scrypt",
"repository": "https://github.com/elithrar/simple-scrypt", "repository": "https://github.com/elithrar/simple-scrypt",
"revision": "2325946f714c95de4a6088202c402fbdfa64163b", "revision": "6724715de445c2e70cdafb7a1a14c8cfe0984210",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/go-ini/ini", "importpath": "github.com/go-ini/ini",
"repository": "https://github.com/go-ini/ini", "repository": "https://github.com/go-ini/ini",
"revision": "e7fea39b01aea8d5671f6858f0532f56e8bff3a5", "revision": "3d73f4b845efdf9989fffd4b4e562727744a34ba",
"branch": "master" "branch": "master"
}, },
{ {
@ -34,122 +34,122 @@
{ {
"importpath": "github.com/kurin/blazer", "importpath": "github.com/kurin/blazer",
"repository": "https://github.com/kurin/blazer", "repository": "https://github.com/kurin/blazer",
"revision": "d1b9d31c8641e46f2651fe564ee9ddb857c1ed29", "revision": "612082ed2430716569f1ec816fc6ade849020816",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/minio/go-homedir", "importpath": "github.com/minio/go-homedir",
"repository": "https://github.com/minio/go-homedir", "repository": "https://github.com/minio/go-homedir",
"revision": "0b1069c753c94b3633cc06a1995252dbcc27c7a6", "revision": "21304a94172ae3a09dee2cd86a12fb6f842138c7",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/minio/minio-go", "importpath": "github.com/minio/minio-go",
"repository": "https://github.com/minio/minio-go", "repository": "https://github.com/minio/minio-go",
"revision": "fe53a65ebc43b5d22626b29a19a3de81170e42d3", "revision": "bd8e1d8a93f006a0207e026353bf0644ffcdd320",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/ncw/swift", "importpath": "github.com/ncw/swift",
"repository": "https://github.com/ncw/swift", "repository": "https://github.com/ncw/swift",
"revision": "bf51ccd3b5c3a1f12ac762b4511c5f9f1ce6b26f", "revision": "9e6fdb8957a022d5780a78b58d6181c3580bb01f",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/pkg/errors", "importpath": "github.com/pkg/errors",
"repository": "https://github.com/pkg/errors", "repository": "https://github.com/pkg/errors",
"revision": "645ef00459ed84a119197bfb8d8205042c6df63d", "revision": "c605e284fe17294bda444b34710735b29d1a9d90",
"branch": "HEAD" "branch": "master"
}, },
{ {
"importpath": "github.com/pkg/profile", "importpath": "github.com/pkg/profile",
"repository": "https://github.com/pkg/profile", "repository": "https://github.com/pkg/profile",
"revision": "1c16f117a3ab788fdf0e334e623b8bccf5679866", "revision": "5b67d428864e92711fcbd2f8629456121a56d91f",
"branch": "HEAD" "branch": "master"
}, },
{ {
"importpath": "github.com/pkg/sftp", "importpath": "github.com/pkg/sftp",
"repository": "https://github.com/pkg/sftp", "repository": "https://github.com/pkg/sftp",
"revision": "8197a2e580736b78d704be0fc47b2324c0591a32", "revision": "314a5ccb89b21e053d7d96c3a706eacaf2b18231",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/pkg/xattr", "importpath": "github.com/pkg/xattr",
"repository": "https://github.com/pkg/xattr", "repository": "https://github.com/pkg/xattr",
"revision": "858d49c224b241ba9393e20f521f6a76f52dd482", "revision": "2c7218aab2e9980561010ef420b53d948749deaf",
"branch": "HEAD" "branch": "master"
}, },
{ {
"importpath": "github.com/restic/chunker", "importpath": "github.com/restic/chunker",
"repository": "https://github.com/restic/chunker", "repository": "https://github.com/restic/chunker",
"revision": "bb2ecf9a98e35a0b336ffc23fc515fb6e7961577", "revision": "1542d55ca53d2d8d7b38e890f7a4be90014356af",
"branch": "HEAD" "branch": "master"
}, },
{ {
"importpath": "github.com/spf13/cobra", "importpath": "github.com/spf13/cobra",
"repository": "https://github.com/spf13/cobra", "repository": "https://github.com/spf13/cobra",
"revision": "10f6b9d7e1631a54ad07c5c0fb71c28a1abfd3c2", "revision": "d994347edadc56d6a7f863775fb6887606685ae6",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "github.com/spf13/pflag", "importpath": "github.com/spf13/pflag",
"repository": "https://github.com/spf13/pflag", "repository": "https://github.com/spf13/pflag",
"revision": "2300d0f8576fe575f71aaa5b9bbe4e1b0dc2eb51", "revision": "e57e3eeb33f795204c1ca35f56c44f83227c6e66",
"branch": "master" "branch": "master"
}, },
{ {
"importpath": "golang.org/x/crypto/curve25519", "importpath": "golang.org/x/crypto/curve25519",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/curve25519" "path": "/curve25519"
}, },
{ {
"importpath": "golang.org/x/crypto/ed25519", "importpath": "golang.org/x/crypto/ed25519",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/ed25519" "path": "/ed25519"
}, },
{ {
"importpath": "golang.org/x/crypto/pbkdf2", "importpath": "golang.org/x/crypto/pbkdf2",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/pbkdf2" "path": "/pbkdf2"
}, },
{ {
"importpath": "golang.org/x/crypto/poly1305", "importpath": "golang.org/x/crypto/poly1305",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/poly1305" "path": "/poly1305"
}, },
{ {
"importpath": "golang.org/x/crypto/scrypt", "importpath": "golang.org/x/crypto/scrypt",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/scrypt" "path": "/scrypt"
}, },
{ {
"importpath": "golang.org/x/crypto/ssh", "importpath": "golang.org/x/crypto/ssh",
"repository": "https://go.googlesource.com/crypto", "repository": "https://go.googlesource.com/crypto",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117", "revision": "7f7c0c2d75ebb4e32a21396ce36e87b6dadc91c9",
"branch": "master", "branch": "master",
"path": "/ssh" "path": "/ssh"
}, },
{ {
"importpath": "golang.org/x/net/context", "importpath": "golang.org/x/net/context",
"repository": "https://go.googlesource.com/net", "repository": "https://go.googlesource.com/net",
"revision": "5602c733f70afc6dcec6766be0d5034d4c4f14de", "revision": "b3756b4b77d7b13260a0a2ec658753cf48922eac",
"branch": "master", "branch": "master",
"path": "/context" "path": "/context"
}, },
{ {
"importpath": "golang.org/x/sys/unix", "importpath": "golang.org/x/sys/unix",
"repository": "https://go.googlesource.com/sys", "repository": "https://go.googlesource.com/sys",
"revision": "f3918c30c5c2cb527c0b071a27c35120a6c0719a", "revision": "4cd6d1a821c7175768725b55ca82f14683a29ea4",
"branch": "master", "branch": "master",
"path": "/unix" "path": "/unix"
} }

File diff suppressed because one or more lines are too long

View File

@ -83,8 +83,8 @@ sec1, err := cfg.GetSection("Section")
sec2, err := cfg.GetSection("SecTIOn") sec2, err := cfg.GetSection("SecTIOn")
// key1 and key2 are the exactly same key object // key1 and key2 are the exactly same key object
key1, err := cfg.GetKey("Key") key1, err := sec1.GetKey("Key")
key2, err := cfg.GetKey("KeY") key2, err := sec2.GetKey("KeY")
``` ```
#### MySQL-like boolean key #### MySQL-like boolean key
@ -122,6 +122,12 @@ Take care that following format will be treated as comment:
If you want to save a value with `#` or `;`, please quote them with ``` ` ``` or ``` """ ```. If you want to save a value with `#` or `;`, please quote them with ``` ` ``` or ``` """ ```.
Alternatively, you can use following `LoadOptions` to completely ignore inline comments:
```go
cfg, err := LoadSources(LoadOptions{IgnoreInlineComment: true}, "app.ini"))
```
### Working with sections ### Working with sections
To get a section, you would need to: To get a section, you would need to:

View File

@ -76,8 +76,8 @@ sec1, err := cfg.GetSection("Section")
sec2, err := cfg.GetSection("SecTIOn") sec2, err := cfg.GetSection("SecTIOn")
// key1 和 key2 指向同一个键对象 // key1 和 key2 指向同一个键对象
key1, err := cfg.GetKey("Key") key1, err := sec1.GetKey("Key")
key2, err := cfg.GetKey("KeY") key2, err := sec2.GetKey("KeY")
``` ```
#### 类似 MySQL 配置中的布尔值键 #### 类似 MySQL 配置中的布尔值键
@ -115,6 +115,12 @@ key, err := sec.NewBooleanKey("skip-host-cache")
如果你希望使用包含 `#``;` 的值,请使用 ``` ` ``` 或 ``` """ ``` 进行包覆。 如果你希望使用包含 `#``;` 的值,请使用 ``` ` ``` 或 ``` """ ``` 进行包覆。
除此之外,您还可以通过 `LoadOptions` 完全忽略行内注释:
```go
cfg, err := LoadSources(LoadOptions{IgnoreInlineComment: true}, "app.ini"))
```
### 操作分区Section ### 操作分区Section
获取指定分区: 获取指定分区:

View File

@ -37,7 +37,7 @@ const (
// Maximum allowed depth when recursively substituing variable names. // Maximum allowed depth when recursively substituing variable names.
_DEPTH_VALUES = 99 _DEPTH_VALUES = 99
_VERSION = "1.27.0" _VERSION = "1.28.1"
) )
// Version returns current package version literal. // Version returns current package version literal.
@ -60,6 +60,9 @@ var (
// Explicitly write DEFAULT section header // Explicitly write DEFAULT section header
DefaultHeader = false DefaultHeader = false
// Indicate whether to put a line between sections
PrettySection = true
) )
func init() { func init() {
@ -504,7 +507,7 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
// In case key value contains "\n", "`", "\"", "#" or ";" // In case key value contains "\n", "`", "\"", "#" or ";"
if strings.ContainsAny(val, "\n`") { if strings.ContainsAny(val, "\n`") {
val = `"""` + val + `"""` val = `"""` + val + `"""`
} else if strings.ContainsAny(val, "#;") { } else if !f.options.IgnoreInlineComment && strings.ContainsAny(val, "#;") {
val = "`" + val + "`" val = "`" + val + "`"
} }
if _, err = buf.WriteString(equalSign + val + LineBreak); err != nil { if _, err = buf.WriteString(equalSign + val + LineBreak); err != nil {
@ -513,9 +516,11 @@ func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) {
} }
} }
// Put a line between sections if PrettySection {
if _, err = buf.WriteString(LineBreak); err != nil { // Put a line between sections
return 0, err if _, err = buf.WriteString(LineBreak); err != nil {
return 0, err
}
} }
} }

View File

@ -215,6 +215,13 @@ key2=value #comment2`))
So(cfg.Section("").Key("key1").String(), ShouldEqual, `value ;comment`) So(cfg.Section("").Key("key1").String(), ShouldEqual, `value ;comment`)
So(cfg.Section("").Key("key2").String(), ShouldEqual, `value #comment2`) So(cfg.Section("").Key("key2").String(), ShouldEqual, `value #comment2`)
var buf bytes.Buffer
cfg.WriteTo(&buf)
So(buf.String(), ShouldEqual, `key1 = value ;comment
key2 = value #comment2
`)
}) })
Convey("Load with boolean type keys", t, func() { Convey("Load with boolean type keys", t, func() {

View File

@ -189,7 +189,7 @@ func (p *parser) readContinuationLines(val string) (string, error) {
// are quotes \" or \'. // are quotes \" or \'.
// It returns false if any other parts also contain same kind of quotes. // It returns false if any other parts also contain same kind of quotes.
func hasSurroundedQuote(in string, quote byte) bool { func hasSurroundedQuote(in string, quote byte) bool {
return len(in) > 2 && in[0] == quote && in[len(in)-1] == quote && return len(in) >= 2 && in[0] == quote && in[len(in)-1] == quote &&
strings.IndexByte(in[1:], quote) == len(in)-2 strings.IndexByte(in[1:], quote) == len(in)-2
} }

View File

@ -78,7 +78,7 @@ func parseDelim(actual string) string {
var reflectTime = reflect.TypeOf(time.Now()).Kind() var reflectTime = reflect.TypeOf(time.Now()).Kind()
// setSliceWithProperType sets proper values to slice based on its type. // setSliceWithProperType sets proper values to slice based on its type.
func setSliceWithProperType(key *Key, field reflect.Value, delim string, allowShadow bool) error { func setSliceWithProperType(key *Key, field reflect.Value, delim string, allowShadow, isStrict bool) error {
var strs []string var strs []string
if allowShadow { if allowShadow {
strs = key.StringsWithShadows(delim) strs = key.StringsWithShadows(delim)
@ -92,26 +92,30 @@ func setSliceWithProperType(key *Key, field reflect.Value, delim string, allowSh
} }
var vals interface{} var vals interface{}
var err error
sliceOf := field.Type().Elem().Kind() sliceOf := field.Type().Elem().Kind()
switch sliceOf { switch sliceOf {
case reflect.String: case reflect.String:
vals = strs vals = strs
case reflect.Int: case reflect.Int:
vals, _ = key.parseInts(strs, true, false) vals, err = key.parseInts(strs, true, false)
case reflect.Int64: case reflect.Int64:
vals, _ = key.parseInt64s(strs, true, false) vals, err = key.parseInt64s(strs, true, false)
case reflect.Uint: case reflect.Uint:
vals, _ = key.parseUints(strs, true, false) vals, err = key.parseUints(strs, true, false)
case reflect.Uint64: case reflect.Uint64:
vals, _ = key.parseUint64s(strs, true, false) vals, err = key.parseUint64s(strs, true, false)
case reflect.Float64: case reflect.Float64:
vals, _ = key.parseFloat64s(strs, true, false) vals, err = key.parseFloat64s(strs, true, false)
case reflectTime: case reflectTime:
vals, _ = key.parseTimesFormat(time.RFC3339, strs, true, false) vals, err = key.parseTimesFormat(time.RFC3339, strs, true, false)
default: default:
return fmt.Errorf("unsupported type '[]%s'", sliceOf) return fmt.Errorf("unsupported type '[]%s'", sliceOf)
} }
if isStrict {
return err
}
slice := reflect.MakeSlice(field.Type(), numVals, numVals) slice := reflect.MakeSlice(field.Type(), numVals, numVals)
for i := 0; i < numVals; i++ { for i := 0; i < numVals; i++ {
@ -136,10 +140,17 @@ func setSliceWithProperType(key *Key, field reflect.Value, delim string, allowSh
return nil return nil
} }
func wrapStrictError(err error, isStrict bool) error {
if isStrict {
return err
}
return nil
}
// setWithProperType sets proper value to field based on its type, // setWithProperType sets proper value to field based on its type,
// but it does not return error for failing parsing, // but it does not return error for failing parsing,
// because we want to use default value that is already assigned to strcut. // because we want to use default value that is already assigned to strcut.
func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim string, allowShadow bool) error { func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim string, allowShadow, isStrict bool) error {
switch t.Kind() { switch t.Kind() {
case reflect.String: case reflect.String:
if len(key.String()) == 0 { if len(key.String()) == 0 {
@ -149,7 +160,7 @@ func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim stri
case reflect.Bool: case reflect.Bool:
boolVal, err := key.Bool() boolVal, err := key.Bool()
if err != nil { if err != nil {
return nil return wrapStrictError(err, isStrict)
} }
field.SetBool(boolVal) field.SetBool(boolVal)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -161,8 +172,8 @@ func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim stri
} }
intVal, err := key.Int64() intVal, err := key.Int64()
if err != nil || intVal == 0 { if err != nil {
return nil return wrapStrictError(err, isStrict)
} }
field.SetInt(intVal) field.SetInt(intVal)
// byte is an alias for uint8, so supporting uint8 breaks support for byte // byte is an alias for uint8, so supporting uint8 breaks support for byte
@ -176,24 +187,24 @@ func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim stri
uintVal, err := key.Uint64() uintVal, err := key.Uint64()
if err != nil { if err != nil {
return nil return wrapStrictError(err, isStrict)
} }
field.SetUint(uintVal) field.SetUint(uintVal)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
floatVal, err := key.Float64() floatVal, err := key.Float64()
if err != nil { if err != nil {
return nil return wrapStrictError(err, isStrict)
} }
field.SetFloat(floatVal) field.SetFloat(floatVal)
case reflectTime: case reflectTime:
timeVal, err := key.Time() timeVal, err := key.Time()
if err != nil { if err != nil {
return nil return wrapStrictError(err, isStrict)
} }
field.Set(reflect.ValueOf(timeVal)) field.Set(reflect.ValueOf(timeVal))
case reflect.Slice: case reflect.Slice:
return setSliceWithProperType(key, field, delim, allowShadow) return setSliceWithProperType(key, field, delim, allowShadow, isStrict)
default: default:
return fmt.Errorf("unsupported type '%s'", t) return fmt.Errorf("unsupported type '%s'", t)
} }
@ -212,7 +223,7 @@ func parseTagOptions(tag string) (rawName string, omitEmpty bool, allowShadow bo
return rawName, omitEmpty, allowShadow return rawName, omitEmpty, allowShadow
} }
func (s *Section) mapTo(val reflect.Value) error { func (s *Section) mapTo(val reflect.Value, isStrict bool) error {
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Ptr {
val = val.Elem() val = val.Elem()
} }
@ -241,7 +252,7 @@ func (s *Section) mapTo(val reflect.Value) error {
if isAnonymous || isStruct { if isAnonymous || isStruct {
if sec, err := s.f.GetSection(fieldName); err == nil { if sec, err := s.f.GetSection(fieldName); err == nil {
if err = sec.mapTo(field); err != nil { if err = sec.mapTo(field, isStrict); err != nil {
return fmt.Errorf("error mapping field(%s): %v", fieldName, err) return fmt.Errorf("error mapping field(%s): %v", fieldName, err)
} }
continue continue
@ -250,7 +261,7 @@ func (s *Section) mapTo(val reflect.Value) error {
if key, err := s.GetKey(fieldName); err == nil { if key, err := s.GetKey(fieldName); err == nil {
delim := parseDelim(tpField.Tag.Get("delim")) delim := parseDelim(tpField.Tag.Get("delim"))
if err = setWithProperType(tpField.Type, key, field, delim, allowShadow); err != nil { if err = setWithProperType(tpField.Type, key, field, delim, allowShadow, isStrict); err != nil {
return fmt.Errorf("error mapping field(%s): %v", fieldName, err) return fmt.Errorf("error mapping field(%s): %v", fieldName, err)
} }
} }
@ -269,7 +280,22 @@ func (s *Section) MapTo(v interface{}) error {
return errors.New("cannot map to non-pointer struct") return errors.New("cannot map to non-pointer struct")
} }
return s.mapTo(val) return s.mapTo(val, false)
}
// MapTo maps section to given struct in strict mode,
// which returns all possible error including value parsing error.
func (s *Section) StrictMapTo(v interface{}) error {
typ := reflect.TypeOf(v)
val := reflect.ValueOf(v)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
val = val.Elem()
} else {
return errors.New("cannot map to non-pointer struct")
}
return s.mapTo(val, true)
} }
// MapTo maps file to given struct. // MapTo maps file to given struct.
@ -277,6 +303,12 @@ func (f *File) MapTo(v interface{}) error {
return f.Section("").MapTo(v) return f.Section("").MapTo(v)
} }
// MapTo maps file to given struct in strict mode,
// which returns all possible error including value parsing error.
func (f *File) StrictMapTo(v interface{}) error {
return f.Section("").StrictMapTo(v)
}
// MapTo maps data sources to given struct with name mapper. // MapTo maps data sources to given struct with name mapper.
func MapToWithMapper(v interface{}, mapper NameMapper, source interface{}, others ...interface{}) error { func MapToWithMapper(v interface{}, mapper NameMapper, source interface{}, others ...interface{}) error {
cfg, err := Load(source, others...) cfg, err := Load(source, others...)
@ -287,11 +319,28 @@ func MapToWithMapper(v interface{}, mapper NameMapper, source interface{}, other
return cfg.MapTo(v) return cfg.MapTo(v)
} }
// StrictMapToWithMapper maps data sources to given struct with name mapper in strict mode,
// which returns all possible error including value parsing error.
func StrictMapToWithMapper(v interface{}, mapper NameMapper, source interface{}, others ...interface{}) error {
cfg, err := Load(source, others...)
if err != nil {
return err
}
cfg.NameMapper = mapper
return cfg.StrictMapTo(v)
}
// MapTo maps data sources to given struct. // MapTo maps data sources to given struct.
func MapTo(v, source interface{}, others ...interface{}) error { func MapTo(v, source interface{}, others ...interface{}) error {
return MapToWithMapper(v, nil, source, others...) return MapToWithMapper(v, nil, source, others...)
} }
// StrictMapTo maps data sources to given struct in strict mode,
// which returns all possible error including value parsing error.
func StrictMapTo(v, source interface{}, others ...interface{}) error {
return StrictMapToWithMapper(v, nil, source, others...)
}
// reflectSliceWithProperType does the opposite thing as setSliceWithProperType. // reflectSliceWithProperType does the opposite thing as setSliceWithProperType.
func reflectSliceWithProperType(key *Key, field reflect.Value, delim string) error { func reflectSliceWithProperType(key *Key, field reflect.Value, delim string) error {
slice := field.Slice(0, field.Len()) slice := field.Slice(0, field.Len())
@ -359,10 +408,11 @@ func isEmptyValue(v reflect.Value) bool {
return v.Uint() == 0 return v.Uint() == 0
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return v.Float() == 0 return v.Float() == 0
case reflectTime:
return v.Interface().(time.Time).IsZero()
case reflect.Interface, reflect.Ptr: case reflect.Interface, reflect.Ptr:
return v.IsNil() return v.IsNil()
case reflectTime:
t, ok := v.Interface().(time.Time)
return ok && t.IsZero()
} }
return false return false
} }

View File

@ -229,6 +229,21 @@ func Test_Struct(t *testing.T) {
}) })
}) })
Convey("Map to struct in strict mode", t, func() {
cfg, err := Load([]byte(`
name=bruce
age=a30`))
So(err, ShouldBeNil)
type Strict struct {
Name string `ini:"name"`
Age int `ini:"age"`
}
s := new(Strict)
So(cfg.Section("").StrictMapTo(s), ShouldNotBeNil)
})
Convey("Reflect from struct", t, func() { Convey("Reflect from struct", t, func() {
type Embeded struct { type Embeded struct {
Dates []time.Time `delim:"|"` Dates []time.Time `delim:"|"`

View File

@ -77,10 +77,7 @@ Downloading is as simple as uploading:
```go ```go
func downloadFile(ctx context.Context, bucket *b2.Bucket, downloads int, src, dst string) error { func downloadFile(ctx context.Context, bucket *b2.Bucket, downloads int, src, dst string) error {
r, err := bucket.Object(src).NewReader(ctx) r := bucket.Object(src).NewReader(ctx)
if err != nil {
return err
}
defer r.Close() defer r.Close()
f, err := file.Create(dst) f, err := file.Create(dst)

View File

@ -332,11 +332,9 @@ func (o *Object) Name() string {
// Attrs returns an object's attributes. // Attrs returns an object's attributes.
func (o *Object) Attrs(ctx context.Context) (*Attrs, error) { func (o *Object) Attrs(ctx context.Context) (*Attrs, error) {
f, err := o.b.b.downloadFileByName(ctx, o.name, 0, 1) if err := o.ensure(ctx); err != nil {
if err != nil {
return nil, err return nil, err
} }
o.f = o.b.b.file(f.id())
fi, err := o.f.getFileInfo(ctx) fi, err := o.f.getFileInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -585,20 +583,14 @@ func (b *Bucket) Reveal(ctx context.Context, name string) error {
} }
func (b *Bucket) getObject(ctx context.Context, name string) (*Object, error) { func (b *Bucket) getObject(ctx context.Context, name string) (*Object, error) {
fs, _, err := b.b.listFileNames(ctx, 1, name, "", "") fr, err := b.b.downloadFileByName(ctx, name, 0, 1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(fs) < 1 { fr.Close()
return nil, b2err{err: fmt.Errorf("%s: not found", name), notFoundErr: true}
}
f := fs[0]
if f.name() != name {
return nil, b2err{err: fmt.Errorf("%s: not found", name), notFoundErr: true}
}
return &Object{ return &Object{
name: name, name: name,
f: f, f: b.b.file(fr.id(), name),
b: b, b: b,
}, nil }, nil
} }

View File

@ -32,10 +32,12 @@ import (
const ( const (
bucketName = "b2-tests" bucketName = "b2-tests"
smallFileName = "TeenyTiny" smallFileName = "Teeny Tiny"
largeFileName = "BigBytes" largeFileName = "BigBytes"
) )
var gmux = &sync.Mutex{}
type testError struct { type testError struct {
retry bool retry bool
backoff time.Duration backoff time.Duration
@ -167,6 +169,8 @@ func (t *testBucket) startLargeFile(_ context.Context, name, _ string, _ map[str
func (t *testBucket) listFileNames(ctx context.Context, count int, cont, pfx, del string) ([]b2FileInterface, string, error) { func (t *testBucket) listFileNames(ctx context.Context, count int, cont, pfx, del string) ([]b2FileInterface, string, error) {
var f []string var f []string
gmux.Lock()
defer gmux.Unlock()
for name := range t.files { for name := range t.files {
f = append(f, name) f = append(f, name)
} }
@ -196,6 +200,8 @@ func (t *testBucket) listFileVersions(ctx context.Context, count int, a, b, c, d
} }
func (t *testBucket) downloadFileByName(_ context.Context, name string, offset, size int64) (b2FileReaderInterface, error) { func (t *testBucket) downloadFileByName(_ context.Context, name string, offset, size int64) (b2FileReaderInterface, error) {
gmux.Lock()
defer gmux.Unlock()
f := t.files[name] f := t.files[name]
end := int(offset + size) end := int(offset + size)
if end >= len(f) { if end >= len(f) {
@ -215,8 +221,8 @@ func (t *testBucket) hideFile(context.Context, string) (b2FileInterface, error)
func (t *testBucket) getDownloadAuthorization(context.Context, string, time.Duration) (string, error) { func (t *testBucket) getDownloadAuthorization(context.Context, string, time.Duration) (string, error) {
return "", nil return "", nil
} }
func (t *testBucket) baseURL() string { return "" } func (t *testBucket) baseURL() string { return "" }
func (t *testBucket) file(id string) b2FileInterface { return nil } func (t *testBucket) file(id, name string) b2FileInterface { return nil }
type testURL struct { type testURL struct {
files map[string]string files map[string]string
@ -229,6 +235,8 @@ func (t *testURL) uploadFile(_ context.Context, r io.Reader, _ int, name, _, _ s
if _, err := io.Copy(buf, r); err != nil { if _, err := io.Copy(buf, r); err != nil {
return nil, err return nil, err
} }
gmux.Lock()
defer gmux.Unlock()
t.files[name] = buf.String() t.files[name] = buf.String()
return &testFile{ return &testFile{
n: name, n: name,
@ -239,7 +247,6 @@ func (t *testURL) uploadFile(_ context.Context, r io.Reader, _ int, name, _, _ s
type testLargeFile struct { type testLargeFile struct {
name string name string
mux sync.Mutex
parts map[int][]byte parts map[int][]byte
files map[string]string files map[string]string
errs *errCont errs *errCont
@ -247,6 +254,8 @@ type testLargeFile struct {
func (t *testLargeFile) finishLargeFile(context.Context) (b2FileInterface, error) { func (t *testLargeFile) finishLargeFile(context.Context) (b2FileInterface, error) {
var total []byte var total []byte
gmux.Lock()
defer gmux.Unlock()
for i := 1; i <= len(t.parts); i++ { for i := 1; i <= len(t.parts); i++ {
total = append(total, t.parts[i]...) total = append(total, t.parts[i]...)
} }
@ -259,15 +268,15 @@ func (t *testLargeFile) finishLargeFile(context.Context) (b2FileInterface, error
} }
func (t *testLargeFile) getUploadPartURL(context.Context) (b2FileChunkInterface, error) { func (t *testLargeFile) getUploadPartURL(context.Context) (b2FileChunkInterface, error) {
gmux.Lock()
defer gmux.Unlock()
return &testFileChunk{ return &testFileChunk{
parts: t.parts, parts: t.parts,
mux: &t.mux,
errs: t.errs, errs: t.errs,
}, nil }, nil
} }
type testFileChunk struct { type testFileChunk struct {
mux *sync.Mutex
parts map[int][]byte parts map[int][]byte
errs *errCont errs *errCont
} }
@ -283,9 +292,9 @@ func (t *testFileChunk) uploadPart(_ context.Context, r io.Reader, _ string, _,
if err != nil { if err != nil {
return int(i), err return int(i), err
} }
t.mux.Lock() gmux.Lock()
defer gmux.Unlock()
t.parts[index] = buf.Bytes() t.parts[index] = buf.Bytes()
t.mux.Unlock()
return int(i), nil return int(i), nil
} }
@ -315,6 +324,8 @@ func (t *testFile) listParts(context.Context, int, int) ([]b2FilePartInterface,
} }
func (t *testFile) deleteFileVersion(context.Context) error { func (t *testFile) deleteFileVersion(context.Context) error {
gmux.Lock()
defer gmux.Unlock()
delete(t.files, t.n) delete(t.files, t.n)
return nil return nil
} }

View File

@ -54,7 +54,7 @@ type beBucketInterface interface {
hideFile(context.Context, string) (beFileInterface, error) hideFile(context.Context, string) (beFileInterface, error)
getDownloadAuthorization(context.Context, string, time.Duration) (string, error) getDownloadAuthorization(context.Context, string, time.Duration) (string, error)
baseURL() string baseURL() string
file(string) beFileInterface file(string, string) beFileInterface
} }
type beBucket struct { type beBucket struct {
@ -407,9 +407,9 @@ func (b *beBucket) baseURL() string {
return b.b2bucket.baseURL() return b.b2bucket.baseURL()
} }
func (b *beBucket) file(id string) beFileInterface { func (b *beBucket) file(id, name string) beFileInterface {
return &beFile{ return &beFile{
b2file: b.b2bucket.file(id), b2file: b.b2bucket.file(id, name),
ri: b.ri, ri: b.ri,
} }
} }

View File

@ -51,7 +51,7 @@ type b2BucketInterface interface {
hideFile(context.Context, string) (b2FileInterface, error) hideFile(context.Context, string) (b2FileInterface, error)
getDownloadAuthorization(context.Context, string, time.Duration) (string, error) getDownloadAuthorization(context.Context, string, time.Duration) (string, error)
baseURL() string baseURL() string
file(string) b2FileInterface file(string, string) b2FileInterface
} }
type b2URLInterface interface { type b2URLInterface interface {
@ -315,8 +315,12 @@ func (b *b2Bucket) listFileVersions(ctx context.Context, count int, nextName, ne
func (b *b2Bucket) downloadFileByName(ctx context.Context, name string, offset, size int64) (b2FileReaderInterface, error) { func (b *b2Bucket) downloadFileByName(ctx context.Context, name string, offset, size int64) (b2FileReaderInterface, error) {
fr, err := b.b.DownloadFileByName(ctx, name, offset, size) fr, err := b.b.DownloadFileByName(ctx, name, offset, size)
if err != nil { if err != nil {
if code, _ := base.Code(err); code == http.StatusRequestedRangeNotSatisfiable { code, _ := base.Code(err)
switch code {
case http.StatusRequestedRangeNotSatisfiable:
return nil, errNoMoreContent return nil, errNoMoreContent
case http.StatusNotFound:
return nil, b2err{err: err, notFoundErr: true}
} }
return nil, err return nil, err
} }
@ -339,7 +343,7 @@ func (b *b2Bucket) baseURL() string {
return b.b.BaseURL() return b.b.BaseURL()
} }
func (b *b2Bucket) file(id string) b2FileInterface { return &b2File{b.b.File(id)} } func (b *b2Bucket) file(id, name string) b2FileInterface { return &b2File{b.b.File(id, name)} }
func (b *b2URL) uploadFile(ctx context.Context, r io.Reader, size int, name, contentType, sha1 string, info map[string]string) (b2FileInterface, error) { func (b *b2URL) uploadFile(ctx context.Context, r io.Reader, size int, name, contentType, sha1 string, info map[string]string) (b2FileInterface, error) {
file, err := b.b.UploadFile(ctx, r, size, name, contentType, sha1, info) file, err := b.b.UploadFile(ctx, r, size, name, contentType, sha1, info)
@ -374,6 +378,9 @@ func (b *b2File) status() string {
} }
func (b *b2File) getFileInfo(ctx context.Context) (b2FileInfoInterface, error) { func (b *b2File) getFileInfo(ctx context.Context) (b2FileInfoInterface, error) {
if b.b.Info != nil {
return &b2FileInfo{b.b.Info}, nil
}
fi, err := b.b.GetFileInfo(ctx) fi, err := b.b.GetFileInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -22,6 +22,7 @@ import (
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"sync"
"testing" "testing"
"time" "time"
@ -219,6 +220,14 @@ func TestAttrs(t *testing.T) {
LastModified: time.Unix(1464370149, 142000000), LastModified: time.Unix(1464370149, 142000000),
Info: map[string]string{}, // can't be nil Info: map[string]string{}, // can't be nil
}, },
&Attrs{
ContentType: "arbitrarystring",
Info: map[string]string{
"spaces": "string with spaces",
"unicode": "日本語",
"special": "&/!@_.~",
},
},
} }
table := []struct { table := []struct {
@ -615,6 +624,105 @@ func TestDuelingBuckets(t *testing.T) {
} }
} }
func TestNotExist(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
bucket, done := startLiveTest(ctx, t)
defer done()
if _, err := bucket.Object("not there").Attrs(ctx); !IsNotExist(err) {
t.Errorf("IsNotExist() on nonexistent object returned false (%v)", err)
}
}
func TestWriteEmpty(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
bucket, done := startLiveTest(ctx, t)
defer done()
_, _, err := writeFile(ctx, bucket, smallFileName, 0, 1e8)
if err != nil {
t.Fatal(err)
}
}
type rtCounter struct {
rt http.RoundTripper
trips int
sync.Mutex
}
func (rt *rtCounter) RoundTrip(r *http.Request) (*http.Response, error) {
rt.Lock()
defer rt.Unlock()
rt.trips++
return rt.rt.RoundTrip(r)
}
func TestAttrsNoRoundtrip(t *testing.T) {
rt := &rtCounter{rt: transport}
transport = rt
defer func() {
transport = rt.rt
}()
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
bucket, done := startLiveTest(ctx, t)
defer done()
_, _, err := writeFile(ctx, bucket, smallFileName, 1e6+42, 1e8)
if err != nil {
t.Fatal(err)
}
objs, _, err := bucket.ListObjects(ctx, 1, nil)
if err != nil {
t.Fatal(err)
}
if len(objs) != 1 {
t.Fatal("unexpected objects: got %d, want 1", len(objs))
}
trips := rt.trips
attrs, err := objs[0].Attrs(ctx)
if err != nil {
t.Fatal(err)
}
if attrs.Name != smallFileName {
t.Errorf("got the wrong object: got %q, want %q", attrs.Name, smallFileName)
}
if trips != rt.trips {
t.Errorf("Attrs() should not have caused any net traffic, but it did: old %d, new %d", trips, rt.trips)
}
}
func TestDeleteWithoutName(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
bucket, done := startLiveTest(ctx, t)
defer done()
_, _, err := writeFile(ctx, bucket, smallFileName, 1e6+42, 1e8)
if err != nil {
t.Fatal(err)
}
if err := bucket.Object(smallFileName).Delete(ctx); err != nil {
t.Fatal(err)
}
}
type object struct { type object struct {
o *Object o *Object
err error err error
@ -655,6 +763,8 @@ func listObjects(ctx context.Context, f func(context.Context, int, *Cursor) ([]*
return ch return ch
} }
var transport = http.DefaultTransport
func startLiveTest(ctx context.Context, t *testing.T) (*Bucket, func()) { func startLiveTest(ctx context.Context, t *testing.T) (*Bucket, func()) {
id := os.Getenv(apiID) id := os.Getenv(apiID)
key := os.Getenv(apiKey) key := os.Getenv(apiKey)
@ -662,7 +772,7 @@ func startLiveTest(ctx context.Context, t *testing.T) (*Bucket, func()) {
t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests") t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests")
return nil, nil return nil, nil
} }
client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens()) client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens(), Transport(transport))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return nil, nil return nil, nil

View File

@ -64,16 +64,17 @@ type Writer struct {
contentType string contentType string
info map[string]string info map[string]string
csize int csize int
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ready chan chunk ready chan chunk
wg sync.WaitGroup wg sync.WaitGroup
start sync.Once start sync.Once
once sync.Once once sync.Once
done sync.Once done sync.Once
file beLargeFileInterface file beLargeFileInterface
seen map[int]string seen map[int]string
everStarted bool
o *Object o *Object
name string name string
@ -202,6 +203,7 @@ func (w *Writer) thread() {
// Write satisfies the io.Writer interface. // Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (int, error) { func (w *Writer) Write(p []byte) (int, error) {
w.start.Do(func() { w.start.Do(func() {
w.everStarted = true
w.smux.Lock() w.smux.Lock()
w.smap = make(map[int]*meteredReader) w.smap = make(map[int]*meteredReader)
w.smux.Unlock() w.smux.Unlock()
@ -362,6 +364,9 @@ func (w *Writer) sendChunk() error {
// value of Close on all writers. // value of Close on all writers.
func (w *Writer) Close() error { func (w *Writer) Close() error {
w.done.Do(func() { w.done.Do(func() {
if !w.everStarted {
return
}
defer w.o.b.c.removeWriter(w) defer w.o.b.c.removeWriter(w)
defer w.w.Close() // TODO: log error defer w.w.Close() // TODO: log error
if w.cidx == 0 { if w.cidx == 0 {
@ -419,16 +424,21 @@ type meteredReader struct {
read int64 read int64
size int size int
r io.ReadSeeker r io.ReadSeeker
mux sync.Mutex
} }
func (mr *meteredReader) Read(p []byte) (int, error) { func (mr *meteredReader) Read(p []byte) (int, error) {
mr.mux.Lock()
defer mr.mux.Unlock()
n, err := mr.r.Read(p) n, err := mr.r.Read(p)
atomic.AddInt64(&mr.read, int64(n)) mr.read += int64(n)
return n, err return n, err
} }
func (mr *meteredReader) Seek(offset int64, whence int) (int64, error) { func (mr *meteredReader) Seek(offset int64, whence int) (int64, error) {
atomic.StoreInt64(&mr.read, offset) mr.mux.Lock()
defer mr.mux.Unlock()
mr.read = offset
return mr.r.Seek(offset, whence) return mr.r.Seek(offset, whence)
} }

View File

@ -277,7 +277,7 @@ func (rb *requestBody) getBody() io.Reader {
var reqID int64 var reqID int64
func (o *b2Options) makeRequest(ctx context.Context, method, verb, url string, b2req, b2resp interface{}, headers map[string]string, body *requestBody) error { func (o *b2Options) makeRequest(ctx context.Context, method, verb, uri string, b2req, b2resp interface{}, headers map[string]string, body *requestBody) error {
var args []byte var args []byte
if b2req != nil { if b2req != nil {
enc, err := json.Marshal(b2req) enc, err := json.Marshal(b2req)
@ -290,12 +290,15 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, url string, b
size: int64(len(enc)), size: int64(len(enc)),
} }
} }
req, err := http.NewRequest(verb, url, body.getBody()) req, err := http.NewRequest(verb, uri, body.getBody())
if err != nil { if err != nil {
return err return err
} }
req.ContentLength = body.getSize() req.ContentLength = body.getSize()
for k, v := range headers { for k, v := range headers {
if strings.HasPrefix(k, "X-Bz-Info") || strings.HasPrefix(k, "X-Bz-File-Name") {
v = escape(v)
}
req.Header.Set(k, v) req.Header.Set(k, v)
} }
req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1))) req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1)))
@ -322,6 +325,7 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, url string, b
} }
if reply.err != nil { if reply.err != nil {
// Connection errors are retryable. // Connection errors are retryable.
blog.V(2).Infof(">> %s uri: %v err: %v", method, req.URL, reply.err)
return b2err{ return b2err{
msg: reply.err.Error(), msg: reply.err.Error(),
retry: 1, retry: 1,
@ -613,20 +617,21 @@ type File struct {
Size int64 Size int64
Status string Status string
Timestamp time.Time Timestamp time.Time
Info *FileInfo
id string id string
b2 *B2 b2 *B2
} }
// File returns a bare File struct, but with the appropriate id and b2 // File returns a bare File struct, but with the appropriate id and b2
// interfaces. // interfaces.
func (b *Bucket) File(id string) *File { func (b *Bucket) File(id, name string) *File {
return &File{id: id, b2: b.b2} return &File{id: id, b2: b.b2, Name: name}
} }
// UploadFile wraps b2_upload_file. // UploadFile wraps b2_upload_file.
func (url *URL) UploadFile(ctx context.Context, r io.Reader, size int, name, contentType, sha1 string, info map[string]string) (*File, error) { func (u *URL) UploadFile(ctx context.Context, r io.Reader, size int, name, contentType, sha1 string, info map[string]string) (*File, error) {
headers := map[string]string{ headers := map[string]string{
"Authorization": url.token, "Authorization": u.token,
"X-Bz-File-Name": name, "X-Bz-File-Name": name,
"Content-Type": contentType, "Content-Type": contentType,
"Content-Length": fmt.Sprintf("%d", size), "Content-Length": fmt.Sprintf("%d", size),
@ -636,7 +641,7 @@ func (url *URL) UploadFile(ctx context.Context, r io.Reader, size int, name, con
headers[fmt.Sprintf("X-Bz-Info-%s", k)] = v headers[fmt.Sprintf("X-Bz-Info-%s", k)] = v
} }
b2resp := &b2types.UploadFileResponse{} b2resp := &b2types.UploadFileResponse{}
if err := url.b2.opts.makeRequest(ctx, "b2_upload_file", "POST", url.uri, nil, b2resp, headers, &requestBody{body: r, size: int64(size)}); err != nil { if err := u.b2.opts.makeRequest(ctx, "b2_upload_file", "POST", u.uri, nil, b2resp, headers, &requestBody{body: r, size: int64(size)}); err != nil {
return nil, err return nil, err
} }
return &File{ return &File{
@ -645,7 +650,7 @@ func (url *URL) UploadFile(ctx context.Context, r io.Reader, size int, name, con
Timestamp: millitime(b2resp.Timestamp), Timestamp: millitime(b2resp.Timestamp),
Status: b2resp.Action, Status: b2resp.Action,
id: b2resp.FileID, id: b2resp.FileID,
b2: url.b2, b2: u.b2,
}, nil }, nil
} }
@ -868,8 +873,17 @@ func (b *Bucket) ListFileNames(ctx context.Context, count int, continuation, pre
Size: f.Size, Size: f.Size,
Status: f.Action, Status: f.Action,
Timestamp: millitime(f.Timestamp), Timestamp: millitime(f.Timestamp),
id: f.FileID, Info: &FileInfo{
b2: b.b2, Name: f.Name,
SHA1: f.SHA1,
Size: f.Size,
ContentType: f.ContentType,
Info: f.Info,
Status: f.Action,
Timestamp: millitime(f.Timestamp),
},
id: f.FileID,
b2: b.b2,
}) })
} }
return files, cont, nil return files, cont, nil
@ -899,8 +913,17 @@ func (b *Bucket) ListFileVersions(ctx context.Context, count int, startName, sta
Size: f.Size, Size: f.Size,
Status: f.Action, Status: f.Action,
Timestamp: millitime(f.Timestamp), Timestamp: millitime(f.Timestamp),
id: f.FileID, Info: &FileInfo{
b2: b.b2, Name: f.Name,
SHA1: f.SHA1,
Size: f.Size,
ContentType: f.ContentType,
Info: f.Info,
Status: f.Action,
Timestamp: millitime(f.Timestamp),
},
id: f.FileID,
b2: b.b2,
}) })
} }
return files, b2resp.NextName, b2resp.NextID, nil return files, b2resp.NextName, b2resp.NextID, nil
@ -945,8 +968,8 @@ func mkRange(offset, size int64) string {
// DownloadFileByName wraps b2_download_file_by_name. // DownloadFileByName wraps b2_download_file_by_name.
func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, size int64) (*FileReader, error) { func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, size int64) (*FileReader, error) {
url := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, name) uri := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, name)
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", uri, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -978,6 +1001,7 @@ func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, si
} }
clen, err := strconv.ParseInt(reply.resp.Header.Get("Content-Length"), 10, 64) clen, err := strconv.ParseInt(reply.resp.Header.Get("Content-Length"), 10, 64)
if err != nil { if err != nil {
reply.resp.Body.Close()
return nil, err return nil, err
} }
info := make(map[string]string) info := make(map[string]string)
@ -985,8 +1009,17 @@ func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, si
if !strings.HasPrefix(key, "X-Bz-Info-") { if !strings.HasPrefix(key, "X-Bz-Info-") {
continue continue
} }
name := strings.TrimPrefix(key, "X-Bz-Info-") name, err := unescape(strings.TrimPrefix(key, "X-Bz-Info-"))
info[name] = reply.resp.Header.Get(key) if err != nil {
reply.resp.Body.Close()
return nil, err
}
val, err := unescape(reply.resp.Header.Get(key))
if err != nil {
reply.resp.Body.Close()
return nil, err
}
info[name] = val
} }
return &FileReader{ return &FileReader{
ReadCloser: reply.resp.Body, ReadCloser: reply.resp.Body,
@ -1046,7 +1079,7 @@ func (f *File) GetFileInfo(ctx context.Context) (*FileInfo, error) {
f.Status = b2resp.Action f.Status = b2resp.Action
f.Name = b2resp.Name f.Name = b2resp.Name
f.Timestamp = millitime(b2resp.Timestamp) f.Timestamp = millitime(b2resp.Timestamp)
return &FileInfo{ f.Info = &FileInfo{
Name: b2resp.Name, Name: b2resp.Name,
SHA1: b2resp.SHA1, SHA1: b2resp.SHA1,
Size: b2resp.Size, Size: b2resp.Size,
@ -1054,5 +1087,6 @@ func (f *File) GetFileInfo(ctx context.Context) (*FileInfo, error) {
Info: b2resp.Info, Info: b2resp.Info,
Status: b2resp.Action, Status: b2resp.Action,
Timestamp: millitime(b2resp.Timestamp), Timestamp: millitime(b2resp.Timestamp),
}, nil }
return f.Info, nil
} }

View File

@ -17,10 +17,12 @@ package base
import ( import (
"bytes" "bytes"
"crypto/sha1" "crypto/sha1"
"encoding/json"
"fmt" "fmt"
"io" "io"
"os" "os"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
@ -277,3 +279,140 @@ func compareFileAndInfo(t *testing.T, info *FileInfo, name, sha1 string, imap ma
t.Errorf("got %v, want %v", info.Info, imap) t.Errorf("got %v, want %v", info.Info, imap)
} }
} }
// from https://www.backblaze.com/b2/docs/string_encoding.html
var testCases = `[
{"fullyEncoded": "%20", "minimallyEncoded": "+", "string": " "},
{"fullyEncoded": "%21", "minimallyEncoded": "!", "string": "!"},
{"fullyEncoded": "%22", "minimallyEncoded": "%22", "string": "\""},
{"fullyEncoded": "%23", "minimallyEncoded": "%23", "string": "#"},
{"fullyEncoded": "%24", "minimallyEncoded": "$", "string": "$"},
{"fullyEncoded": "%25", "minimallyEncoded": "%25", "string": "%"},
{"fullyEncoded": "%26", "minimallyEncoded": "%26", "string": "&"},
{"fullyEncoded": "%27", "minimallyEncoded": "'", "string": "'"},
{"fullyEncoded": "%28", "minimallyEncoded": "(", "string": "("},
{"fullyEncoded": "%29", "minimallyEncoded": ")", "string": ")"},
{"fullyEncoded": "%2A", "minimallyEncoded": "*", "string": "*"},
{"fullyEncoded": "%2B", "minimallyEncoded": "%2B", "string": "+"},
{"fullyEncoded": "%2C", "minimallyEncoded": "%2C", "string": ","},
{"fullyEncoded": "%2D", "minimallyEncoded": "-", "string": "-"},
{"fullyEncoded": "%2E", "minimallyEncoded": ".", "string": "."},
{"fullyEncoded": "/", "minimallyEncoded": "/", "string": "/"},
{"fullyEncoded": "%30", "minimallyEncoded": "0", "string": "0"},
{"fullyEncoded": "%31", "minimallyEncoded": "1", "string": "1"},
{"fullyEncoded": "%32", "minimallyEncoded": "2", "string": "2"},
{"fullyEncoded": "%33", "minimallyEncoded": "3", "string": "3"},
{"fullyEncoded": "%34", "minimallyEncoded": "4", "string": "4"},
{"fullyEncoded": "%35", "minimallyEncoded": "5", "string": "5"},
{"fullyEncoded": "%36", "minimallyEncoded": "6", "string": "6"},
{"fullyEncoded": "%37", "minimallyEncoded": "7", "string": "7"},
{"fullyEncoded": "%38", "minimallyEncoded": "8", "string": "8"},
{"fullyEncoded": "%39", "minimallyEncoded": "9", "string": "9"},
{"fullyEncoded": "%3A", "minimallyEncoded": ":", "string": ":"},
{"fullyEncoded": "%3B", "minimallyEncoded": ";", "string": ";"},
{"fullyEncoded": "%3C", "minimallyEncoded": "%3C", "string": "<"},
{"fullyEncoded": "%3D", "minimallyEncoded": "=", "string": "="},
{"fullyEncoded": "%3E", "minimallyEncoded": "%3E", "string": ">"},
{"fullyEncoded": "%3F", "minimallyEncoded": "%3F", "string": "?"},
{"fullyEncoded": "%40", "minimallyEncoded": "@", "string": "@"},
{"fullyEncoded": "%41", "minimallyEncoded": "A", "string": "A"},
{"fullyEncoded": "%42", "minimallyEncoded": "B", "string": "B"},
{"fullyEncoded": "%43", "minimallyEncoded": "C", "string": "C"},
{"fullyEncoded": "%44", "minimallyEncoded": "D", "string": "D"},
{"fullyEncoded": "%45", "minimallyEncoded": "E", "string": "E"},
{"fullyEncoded": "%46", "minimallyEncoded": "F", "string": "F"},
{"fullyEncoded": "%47", "minimallyEncoded": "G", "string": "G"},
{"fullyEncoded": "%48", "minimallyEncoded": "H", "string": "H"},
{"fullyEncoded": "%49", "minimallyEncoded": "I", "string": "I"},
{"fullyEncoded": "%4A", "minimallyEncoded": "J", "string": "J"},
{"fullyEncoded": "%4B", "minimallyEncoded": "K", "string": "K"},
{"fullyEncoded": "%4C", "minimallyEncoded": "L", "string": "L"},
{"fullyEncoded": "%4D", "minimallyEncoded": "M", "string": "M"},
{"fullyEncoded": "%4E", "minimallyEncoded": "N", "string": "N"},
{"fullyEncoded": "%4F", "minimallyEncoded": "O", "string": "O"},
{"fullyEncoded": "%50", "minimallyEncoded": "P", "string": "P"},
{"fullyEncoded": "%51", "minimallyEncoded": "Q", "string": "Q"},
{"fullyEncoded": "%52", "minimallyEncoded": "R", "string": "R"},
{"fullyEncoded": "%53", "minimallyEncoded": "S", "string": "S"},
{"fullyEncoded": "%54", "minimallyEncoded": "T", "string": "T"},
{"fullyEncoded": "%55", "minimallyEncoded": "U", "string": "U"},
{"fullyEncoded": "%56", "minimallyEncoded": "V", "string": "V"},
{"fullyEncoded": "%57", "minimallyEncoded": "W", "string": "W"},
{"fullyEncoded": "%58", "minimallyEncoded": "X", "string": "X"},
{"fullyEncoded": "%59", "minimallyEncoded": "Y", "string": "Y"},
{"fullyEncoded": "%5A", "minimallyEncoded": "Z", "string": "Z"},
{"fullyEncoded": "%5B", "minimallyEncoded": "%5B", "string": "["},
{"fullyEncoded": "%5C", "minimallyEncoded": "%5C", "string": "\\"},
{"fullyEncoded": "%5D", "minimallyEncoded": "%5D", "string": "]"},
{"fullyEncoded": "%5E", "minimallyEncoded": "%5E", "string": "^"},
{"fullyEncoded": "%5F", "minimallyEncoded": "_", "string": "_"},
{"fullyEncoded": "%60", "minimallyEncoded": "%60", "string": "` + "`" + `"},
{"fullyEncoded": "%61", "minimallyEncoded": "a", "string": "a"},
{"fullyEncoded": "%62", "minimallyEncoded": "b", "string": "b"},
{"fullyEncoded": "%63", "minimallyEncoded": "c", "string": "c"},
{"fullyEncoded": "%64", "minimallyEncoded": "d", "string": "d"},
{"fullyEncoded": "%65", "minimallyEncoded": "e", "string": "e"},
{"fullyEncoded": "%66", "minimallyEncoded": "f", "string": "f"},
{"fullyEncoded": "%67", "minimallyEncoded": "g", "string": "g"},
{"fullyEncoded": "%68", "minimallyEncoded": "h", "string": "h"},
{"fullyEncoded": "%69", "minimallyEncoded": "i", "string": "i"},
{"fullyEncoded": "%6A", "minimallyEncoded": "j", "string": "j"},
{"fullyEncoded": "%6B", "minimallyEncoded": "k", "string": "k"},
{"fullyEncoded": "%6C", "minimallyEncoded": "l", "string": "l"},
{"fullyEncoded": "%6D", "minimallyEncoded": "m", "string": "m"},
{"fullyEncoded": "%6E", "minimallyEncoded": "n", "string": "n"},
{"fullyEncoded": "%6F", "minimallyEncoded": "o", "string": "o"},
{"fullyEncoded": "%70", "minimallyEncoded": "p", "string": "p"},
{"fullyEncoded": "%71", "minimallyEncoded": "q", "string": "q"},
{"fullyEncoded": "%72", "minimallyEncoded": "r", "string": "r"},
{"fullyEncoded": "%73", "minimallyEncoded": "s", "string": "s"},
{"fullyEncoded": "%74", "minimallyEncoded": "t", "string": "t"},
{"fullyEncoded": "%75", "minimallyEncoded": "u", "string": "u"},
{"fullyEncoded": "%76", "minimallyEncoded": "v", "string": "v"},
{"fullyEncoded": "%77", "minimallyEncoded": "w", "string": "w"},
{"fullyEncoded": "%78", "minimallyEncoded": "x", "string": "x"},
{"fullyEncoded": "%79", "minimallyEncoded": "y", "string": "y"},
{"fullyEncoded": "%7A", "minimallyEncoded": "z", "string": "z"},
{"fullyEncoded": "%7B", "minimallyEncoded": "%7B", "string": "{"},
{"fullyEncoded": "%7C", "minimallyEncoded": "%7C", "string": "|"},
{"fullyEncoded": "%7D", "minimallyEncoded": "%7D", "string": "}"},
{"fullyEncoded": "%7E", "minimallyEncoded": "~", "string": "~"},
{"fullyEncoded": "%7F", "minimallyEncoded": "%7F", "string": "\u007f"},
{"fullyEncoded": "%E8%87%AA%E7%94%B1", "minimallyEncoded": "%E8%87%AA%E7%94%B1", "string": "\u81ea\u7531"},
{"fullyEncoded": "%F0%90%90%80", "minimallyEncoded": "%F0%90%90%80", "string": "\ud801\udc00"}
]`
type testCase struct {
Full string `json:"fullyEncoded"`
Min string `json:"minimallyEncoded"`
Raw string `json:"string"`
}
func TestEscapes(t *testing.T) {
dec := json.NewDecoder(strings.NewReader(testCases))
var tcs []testCase
if err := dec.Decode(&tcs); err != nil {
t.Fatal(err)
}
for _, tc := range tcs {
en := escape(tc.Raw)
if !(en == tc.Full || en == tc.Min) {
t.Errorf("encode %q: got %q, want %q or %q", tc.Raw, en, tc.Min, tc.Full)
}
m, err := unescape(tc.Min)
if err != nil {
t.Errorf("decode %q: %v", tc.Min, err)
}
if m != tc.Raw {
t.Errorf("decode %q: got %q, want %q", tc.Min, m, tc.Raw)
}
f, err := unescape(tc.Full)
if err != nil {
t.Errorf("decode %q: %v", tc.Full, err)
}
if f != tc.Raw {
t.Errorf("decode %q: got %q, want %q", tc.Full, f, tc.Raw)
}
}
}

View File

@ -0,0 +1,81 @@
// Copyright 2017, Google
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package base
import (
"bytes"
"errors"
"fmt"
)
func noEscape(c byte) bool {
switch c {
case '.', '_', '-', '/', '~', '!', '$', '\'', '(', ')', '*', ';', '=', ':', '@':
return true
}
return false
}
func escape(s string) string {
// cribbed from url.go, kinda
b := &bytes.Buffer{}
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == '/':
b.WriteByte(c)
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9':
b.WriteByte(c)
case noEscape(c):
b.WriteByte(c)
default:
fmt.Fprintf(b, "%%%X", c)
}
}
return b.String()
}
func unescape(s string) (string, error) {
b := &bytes.Buffer{}
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case '/':
b.WriteString("/")
case '+':
b.WriteString(" ")
case '%':
if len(s)-i < 3 {
return "", errors.New("unescape: bad encoding")
}
b.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
i += 2
default:
b.WriteByte(c)
}
}
return b.String(), nil
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}

View File

@ -0,0 +1,134 @@
// Copyright 2017, Google
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This is a simple program that will copy named files into or out of B2.
//
// To copy a file into B2:
//
// B2_ACCOUNT_ID=foo B2_ACCOUNT_KEY=bar simple /path/to/file b2://bucket/path/to/dst
//
// To copy a file out:
//
// B2_ACCOUNT_ID=foo B2_ACCOUNT_KEY=bar simple b2://bucket/path/to/file /path/to/dst
package main
import (
"context"
"flag"
"fmt"
"io"
"net/url"
"os"
"strings"
"github.com/kurin/blazer/b2"
)
func main() {
flag.Parse()
b2id := os.Getenv("B2_ACCOUNT_ID")
b2key := os.Getenv("B2_ACCOUNT_KEY")
args := flag.Args()
if len(args) != 2 {
fmt.Printf("Usage:\n\nsimple [src] [dst]\n")
return
}
src, dst := args[0], args[1]
ctx := context.Background()
c, err := b2.NewClient(ctx, b2id, b2key)
if err != nil {
fmt.Println(err)
return
}
var r io.ReadCloser
var w io.WriteCloser
if strings.HasPrefix(src, "b2://") {
reader, err := b2Reader(ctx, c, src)
if err != nil {
fmt.Println(err)
return
}
r = reader
} else {
f, err := os.Open(src)
if err != nil {
fmt.Println(err)
return
}
r = f
}
// Readers do not need their errors checked on close. (Also it's a little
// silly to defer this in main(), but.)
defer r.Close()
if strings.HasPrefix(dst, "b2://") {
writer, err := b2Writer(ctx, c, dst)
if err != nil {
fmt.Println(err)
return
}
w = writer
} else {
f, err := os.Create(dst)
if err != nil {
fmt.Println(err)
return
}
w = f
}
// Copy and check error.
if _, err := io.Copy(w, r); err != nil {
fmt.Println(err)
return
}
// It is very important to check the error of the writer.
if err := w.Close(); err != nil {
fmt.Println(err)
}
}
func b2Reader(ctx context.Context, c *b2.Client, path string) (io.ReadCloser, error) {
o, err := b2Obj(ctx, c, path)
if err != nil {
return nil, err
}
return o.NewReader(ctx), nil
}
func b2Writer(ctx context.Context, c *b2.Client, path string) (io.WriteCloser, error) {
o, err := b2Obj(ctx, c, path)
if err != nil {
return nil, err
}
return o.NewWriter(ctx), nil
}
func b2Obj(ctx context.Context, c *b2.Client, path string) (*b2.Object, error) {
uri, err := url.Parse(path)
if err != nil {
return nil, err
}
bucket, err := c.Bucket(ctx, uri.Host)
if err != nil {
return nil, err
}
// B2 paths must not begin with /, so trim it here.
return bucket.Object(strings.TrimPrefix(uri.Path, "/")), nil
}

View File

@ -171,14 +171,8 @@ type ListFileNamesRequest struct {
} }
type ListFileNamesResponse struct { type ListFileNamesResponse struct {
Continuation string `json:"nextFileName"` Continuation string `json:"nextFileName"`
Files []struct { Files []GetFileInfoResponse `json:"files"`
FileID string `json:"fileId"`
Name string `json:"fileName"`
Size int64 `json:"size"`
Action string `json:"action"`
Timestamp int64 `json:"uploadTimestamp"`
} `json:"files"`
} }
type ListFileVersionsRequest struct { type ListFileVersionsRequest struct {
@ -191,15 +185,9 @@ type ListFileVersionsRequest struct {
} }
type ListFileVersionsResponse struct { type ListFileVersionsResponse struct {
NextName string `json:"nextFileName"` NextName string `json:"nextFileName"`
NextID string `json:"nextFileId"` NextID string `json:"nextFileId"`
Files []struct { Files []GetFileInfoResponse `json:"files"`
FileID string `json:"fileId"`
Name string `json:"fileName"`
Size int64 `json:"size"`
Action string `json:"action"`
Timestamp int64 `json:"uploadTimestamp"`
} `json:"files"`
} }
type HideFileRequest struct { type HideFileRequest struct {
@ -218,6 +206,7 @@ type GetFileInfoRequest struct {
} }
type GetFileInfoResponse struct { type GetFileInfoResponse struct {
FileID string `json:"fileId"`
Name string `json:"fileName"` Name string `json:"fileName"`
SHA1 string `json:"contentSha1"` SHA1 string `json:"contentSha1"`
Size int64 `json:"contentLength"` Size int64 `json:"contentLength"`

View File

@ -34,7 +34,7 @@ func dir() (string, error) {
cmd.Stdout = &stdout cmd.Stdout = &stdout
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
// If "getent" is missing, ignore it // If "getent" is missing, ignore it
if err == exec.ErrNotFound { if err != exec.ErrNotFound {
return "", err return "", err
} }
} else { } else {

View File

@ -10,6 +10,10 @@ import (
// dir returns the homedir of current user for MS Windows OS. // dir returns the homedir of current user for MS Windows OS.
func dir() (string, error) { func dir() (string, error) {
// First prefer the HOME environmental variable
if home := os.Getenv("HOME"); home != "" {
return home, nil
}
drive := os.Getenv("HOMEDRIVE") drive := os.Getenv("HOMEDRIVE")
path := os.Getenv("HOMEPATH") path := os.Getenv("HOMEPATH")
home := drive + path home := drive + path

View File

@ -1,9 +1,9 @@
package homedir package homedir
import ( import (
"fmt"
"os" "os"
"os/user" "os/user"
"path/filepath"
"testing" "testing"
) )
@ -66,7 +66,7 @@ func TestExpand(t *testing.T) {
{ {
"~/foo", "~/foo",
fmt.Sprintf("%s/foo", u.HomeDir), filepath.Join(u.HomeDir, "foo"),
false, false,
}, },
@ -103,12 +103,12 @@ func TestExpand(t *testing.T) {
DisableCache = true DisableCache = true
defer func() { DisableCache = false }() defer func() { DisableCache = false }()
defer patchEnv("HOME", "/custom/path/")() defer patchEnv("HOME", "/custom/path/")()
expected := "/custom/path/foo/bar" expected := filepath.Join("/", "custom", "path", "foo/bar")
actual, err := Expand("~/foo/bar") actual, err := Expand("~/foo/bar")
if err != nil { if err != nil {
t.Errorf("No error is expected, got: %v", err) t.Errorf("No error is expected, got: %v", err)
} else if actual != "/custom/path/foo/bar" { } else if actual != expected {
t.Errorf("Expected: %v; actual: %v", expected, actual) t.Errorf("Expected: %v; actual: %v", expected, actual)
} }
} }

View File

@ -0,0 +1,532 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/minio/minio-go/pkg/s3utils"
)
// SSEInfo - represents Server-Side-Encryption parameters specified by
// a user.
type SSEInfo struct {
key []byte
algo string
}
// NewSSEInfo - specifies (binary or un-encoded) encryption key and
// algorithm name. If algo is empty, it defaults to "AES256". Ref:
// https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html
func NewSSEInfo(key []byte, algo string) SSEInfo {
if algo == "" {
algo = "AES256"
}
return SSEInfo{key, algo}
}
// internal method that computes SSE-C headers
func (s *SSEInfo) getSSEHeaders(isCopySource bool) map[string]string {
if s == nil {
return nil
}
cs := ""
if isCopySource {
cs = "copy-source-"
}
return map[string]string{
"x-amz-" + cs + "server-side-encryption-customer-algorithm": s.algo,
"x-amz-" + cs + "server-side-encryption-customer-key": base64.StdEncoding.EncodeToString(s.key),
"x-amz-" + cs + "server-side-encryption-customer-key-MD5": base64.StdEncoding.EncodeToString(sumMD5(s.key)),
}
}
// GetSSEHeaders - computes and returns headers for SSE-C as key-value
// pairs. They can be set as metadata in PutObject* requests (for
// encryption) or be set as request headers in `Core.GetObject` (for
// decryption).
func (s *SSEInfo) GetSSEHeaders() map[string]string {
return s.getSSEHeaders(false)
}
// DestinationInfo - type with information about the object to be
// created via server-side copy requests, using the Compose API.
type DestinationInfo struct {
bucket, object string
// key for encrypting destination
encryption *SSEInfo
// if no user-metadata is provided, it is copied from source
// (when there is only once source object in the compose
// request)
userMetadata map[string]string
}
// NewDestinationInfo - creates a compose-object/copy-source
// destination info object.
//
// `encSSEC` is the key info for server-side-encryption with customer
// provided key. If it is nil, no encryption is performed.
//
// `userMeta` is the user-metadata key-value pairs to be set on the
// destination. The keys are automatically prefixed with `x-amz-meta-`
// if needed. If nil is passed, and if only a single source (of any
// size) is provided in the ComposeObject call, then metadata from the
// source is copied to the destination.
func NewDestinationInfo(bucket, object string, encryptSSEC *SSEInfo,
userMeta map[string]string) (d DestinationInfo, err error) {
// Input validation.
if err = s3utils.CheckValidBucketName(bucket); err != nil {
return d, err
}
if err = s3utils.CheckValidObjectName(object); err != nil {
return d, err
}
// Process custom-metadata to remove a `x-amz-meta-` prefix if
// present and validate that keys are distinct (after this
// prefix removal).
m := make(map[string]string)
for k, v := range userMeta {
if strings.HasPrefix(strings.ToLower(k), "x-amz-meta-") {
k = k[len("x-amz-meta-"):]
}
if _, ok := m[k]; ok {
return d, fmt.Errorf("Cannot add both %s and x-amz-meta-%s keys as custom metadata", k, k)
}
m[k] = v
}
return DestinationInfo{
bucket: bucket,
object: object,
encryption: encryptSSEC,
userMetadata: m,
}, nil
}
// getUserMetaHeadersMap - construct appropriate key-value pairs to send
// as headers from metadata map to pass into copy-object request. For
// single part copy-object (i.e. non-multipart object), enable the
// withCopyDirectiveHeader to set the `x-amz-metadata-directive` to
// `REPLACE`, so that metadata headers from the source are not copied
// over.
func (d *DestinationInfo) getUserMetaHeadersMap(withCopyDirectiveHeader bool) map[string]string {
if len(d.userMetadata) == 0 {
return nil
}
r := make(map[string]string)
if withCopyDirectiveHeader {
r["x-amz-metadata-directive"] = "REPLACE"
}
for k, v := range d.userMetadata {
r["x-amz-meta-"+k] = v
}
return r
}
// SourceInfo - represents a source object to be copied, using
// server-side copying APIs.
type SourceInfo struct {
bucket, object string
start, end int64
decryptKey *SSEInfo
// Headers to send with the upload-part-copy request involving
// this source object.
Headers http.Header
}
// NewSourceInfo - create a compose-object/copy-object source info
// object.
//
// `decryptSSEC` is the decryption key using server-side-encryption
// with customer provided key. It may be nil if the source is not
// encrypted.
func NewSourceInfo(bucket, object string, decryptSSEC *SSEInfo) SourceInfo {
r := SourceInfo{
bucket: bucket,
object: object,
start: -1, // range is unspecified by default
decryptKey: decryptSSEC,
Headers: make(http.Header),
}
// Set the source header
r.Headers.Set("x-amz-copy-source", s3utils.EncodePath(bucket+"/"+object))
// Assemble decryption headers for upload-part-copy request
for k, v := range decryptSSEC.getSSEHeaders(true) {
r.Headers.Set(k, v)
}
return r
}
// SetRange - Set the start and end offset of the source object to be
// copied. If this method is not called, the whole source object is
// copied.
func (s *SourceInfo) SetRange(start, end int64) error {
if start > end || start < 0 {
return ErrInvalidArgument("start must be non-negative, and start must be at most end.")
}
// Note that 0 <= start <= end
s.start, s.end = start, end
return nil
}
// SetMatchETagCond - Set ETag match condition. The object is copied
// only if the etag of the source matches the value given here.
func (s *SourceInfo) SetMatchETagCond(etag string) error {
if etag == "" {
return ErrInvalidArgument("ETag cannot be empty.")
}
s.Headers.Set("x-amz-copy-source-if-match", etag)
return nil
}
// SetMatchETagExceptCond - Set the ETag match exception
// condition. The object is copied only if the etag of the source is
// not the value given here.
func (s *SourceInfo) SetMatchETagExceptCond(etag string) error {
if etag == "" {
return ErrInvalidArgument("ETag cannot be empty.")
}
s.Headers.Set("x-amz-copy-source-if-none-match", etag)
return nil
}
// SetModifiedSinceCond - Set the modified since condition.
func (s *SourceInfo) SetModifiedSinceCond(modTime time.Time) error {
if modTime.IsZero() {
return ErrInvalidArgument("Input time cannot be 0.")
}
s.Headers.Set("x-amz-copy-source-if-modified-since", modTime.Format(http.TimeFormat))
return nil
}
// SetUnmodifiedSinceCond - Set the unmodified since condition.
func (s *SourceInfo) SetUnmodifiedSinceCond(modTime time.Time) error {
if modTime.IsZero() {
return ErrInvalidArgument("Input time cannot be 0.")
}
s.Headers.Set("x-amz-copy-source-if-unmodified-since", modTime.Format(http.TimeFormat))
return nil
}
// Helper to fetch size and etag of an object using a StatObject call.
func (s *SourceInfo) getProps(c Client) (size int64, etag string, userMeta map[string]string, err error) {
// Get object info - need size and etag here. Also, decryption
// headers are added to the stat request if given.
var objInfo ObjectInfo
rh := NewGetReqHeaders()
for k, v := range s.decryptKey.getSSEHeaders(false) {
rh.Set(k, v)
}
objInfo, err = c.statObject(s.bucket, s.object, rh)
if err != nil {
err = fmt.Errorf("Could not stat object - %s/%s: %v", s.bucket, s.object, err)
} else {
size = objInfo.Size
etag = objInfo.ETag
userMeta = make(map[string]string)
for k, v := range objInfo.Metadata {
if strings.HasPrefix(k, "x-amz-meta-") {
if len(v) > 0 {
userMeta[k] = v[0]
}
}
}
}
return
}
// uploadPartCopy - helper function to create a part in a multipart
// upload via an upload-part-copy request
// https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPartCopy.html
func (c Client) uploadPartCopy(bucket, object, uploadID string, partNumber int,
headers http.Header) (p CompletePart, err error) {
// Build query parameters
urlValues := make(url.Values)
urlValues.Set("partNumber", strconv.Itoa(partNumber))
urlValues.Set("uploadId", uploadID)
// Send upload-part-copy request
resp, err := c.executeMethod("PUT", requestMetadata{
bucketName: bucket,
objectName: object,
customHeader: headers,
queryValues: urlValues,
})
defer closeResponse(resp)
if err != nil {
return p, err
}
// Check if we got an error response.
if resp.StatusCode != http.StatusOK {
return p, httpRespToErrorResponse(resp, bucket, object)
}
// Decode copy-part response on success.
cpObjRes := copyObjectResult{}
err = xmlDecoder(resp.Body, &cpObjRes)
if err != nil {
return p, err
}
p.PartNumber, p.ETag = partNumber, cpObjRes.ETag
return p, nil
}
// ComposeObject - creates an object using server-side copying of
// existing objects. It takes a list of source objects (with optional
// offsets) and concatenates them into a new object using only
// server-side copying operations.
func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
if len(srcs) < 1 || len(srcs) > maxPartsCount {
return ErrInvalidArgument("There must be as least one and upto 10000 source objects.")
}
srcSizes := make([]int64, len(srcs))
var totalSize, size, totalParts int64
var srcUserMeta map[string]string
var etag string
var err error
for i, src := range srcs {
size, etag, srcUserMeta, err = src.getProps(c)
if err != nil {
return fmt.Errorf("Could not get source props for %s/%s: %v", src.bucket, src.object, err)
}
// Error out if client side encryption is used in this source object when
// more than one source objects are given.
if len(srcs) > 1 && src.Headers.Get("x-amz-meta-x-amz-key") != "" {
return ErrInvalidArgument(
fmt.Sprintf("Client side encryption is used in source object %s/%s", src.bucket, src.object))
}
// Since we did a HEAD to get size, we use the ETag
// value to make sure the object has not changed by
// the time we perform the copy. This is done, only if
// the user has not set their own ETag match
// condition.
if src.Headers.Get("x-amz-copy-source-if-match") == "" {
src.SetMatchETagCond(etag)
}
// Check if a segment is specified, and if so, is the
// segment within object bounds?
if src.start != -1 {
// Since range is specified,
// 0 <= src.start <= src.end
// so only invalid case to check is:
if src.end >= size {
return ErrInvalidArgument(
fmt.Sprintf("SourceInfo %d has invalid segment-to-copy [%d, %d] (size is %d)",
i, src.start, src.end, size))
}
size = src.end - src.start + 1
}
// Only the last source may be less than `absMinPartSize`
if size < absMinPartSize && i < len(srcs)-1 {
return ErrInvalidArgument(
fmt.Sprintf("SourceInfo %d is too small (%d) and it is not the last part", i, size))
}
// Is data to copy too large?
totalSize += size
if totalSize > maxMultipartPutObjectSize {
return ErrInvalidArgument(fmt.Sprintf("Cannot compose an object of size %d (> 5TiB)", totalSize))
}
// record source size
srcSizes[i] = size
// calculate parts needed for current source
totalParts += partsRequired(size)
// Do we need more parts than we are allowed?
if totalParts > maxPartsCount {
return ErrInvalidArgument(fmt.Sprintf(
"Your proposed compose object requires more than %d parts", maxPartsCount))
}
}
// Single source object case (i.e. when only one source is
// involved, it is being copied wholly and at most 5GiB in
// size).
if totalParts == 1 && srcs[0].start == -1 && totalSize <= maxPartSize {
h := srcs[0].Headers
// Add destination encryption headers
for k, v := range dst.encryption.getSSEHeaders(false) {
h.Set(k, v)
}
// If no user metadata is specified (and so, the
// for-loop below is not entered), metadata from the
// source is copied to the destination (due to
// single-part copy-object PUT request behaviour).
for k, v := range dst.getUserMetaHeadersMap(true) {
h.Set(k, v)
}
// Send copy request
resp, err := c.executeMethod("PUT", requestMetadata{
bucketName: dst.bucket,
objectName: dst.object,
customHeader: h,
})
defer closeResponse(resp)
if err != nil {
return err
}
// Check if we got an error response.
if resp.StatusCode != http.StatusOK {
return httpRespToErrorResponse(resp, dst.bucket, dst.object)
}
// Return nil on success.
return nil
}
// Now, handle multipart-copy cases.
// 1. Initiate a new multipart upload.
// Set user-metadata on the destination object. If no
// user-metadata is specified, and there is only one source,
// (only) then metadata from source is copied.
userMeta := dst.getUserMetaHeadersMap(false)
metaMap := userMeta
if len(userMeta) == 0 && len(srcs) == 1 {
metaMap = srcUserMeta
}
metaHeaders := make(map[string][]string)
for k, v := range metaMap {
metaHeaders[k] = append(metaHeaders[k], v)
}
uploadID, err := c.newUploadID(dst.bucket, dst.object, metaHeaders)
if err != nil {
return fmt.Errorf("Error creating new upload: %v", err)
}
// 2. Perform copy part uploads
objParts := []CompletePart{}
partIndex := 1
for i, src := range srcs {
h := src.Headers
// Add destination encryption headers
for k, v := range dst.encryption.getSSEHeaders(false) {
h.Set(k, v)
}
// calculate start/end indices of parts after
// splitting.
startIdx, endIdx := calculateEvenSplits(srcSizes[i], src)
for j, start := range startIdx {
end := endIdx[j]
// Add (or reset) source range header for
// upload part copy request.
h.Set("x-amz-copy-source-range",
fmt.Sprintf("bytes=%d-%d", start, end))
// make upload-part-copy request
complPart, err := c.uploadPartCopy(dst.bucket,
dst.object, uploadID, partIndex, h)
if err != nil {
return fmt.Errorf("Error in upload-part-copy - %v", err)
}
objParts = append(objParts, complPart)
partIndex++
}
}
// 3. Make final complete-multipart request.
_, err = c.completeMultipartUpload(dst.bucket, dst.object, uploadID,
completeMultipartUpload{Parts: objParts})
if err != nil {
err = fmt.Errorf("Error in complete-multipart request - %v", err)
}
return err
}
// partsRequired is ceiling(size / copyPartSize)
func partsRequired(size int64) int64 {
r := size / copyPartSize
if size%copyPartSize > 0 {
r++
}
return r
}
// calculateEvenSplits - computes splits for a source and returns
// start and end index slices. Splits happen evenly to be sure that no
// part is less than 5MiB, as that could fail the multipart request if
// it is not the last part.
func calculateEvenSplits(size int64, src SourceInfo) (startIndex, endIndex []int64) {
if size == 0 {
return
}
reqParts := partsRequired(size)
startIndex = make([]int64, reqParts)
endIndex = make([]int64, reqParts)
// Compute number of required parts `k`, as:
//
// k = ceiling(size / copyPartSize)
//
// Now, distribute the `size` bytes in the source into
// k parts as evenly as possible:
//
// r parts sized (q+1) bytes, and
// (k - r) parts sized q bytes, where
//
// size = q * k + r (by simple division of size by k,
// so that 0 <= r < k)
//
start := src.start
if start == -1 {
start = 0
}
quot, rem := size/reqParts, size%reqParts
nextStart := start
for j := int64(0); j < reqParts; j++ {
curPartSize := quot
if j < rem {
curPartSize++
}
cStart := nextStart
cEnd := cStart + curPartSize - 1
nextStart = cEnd + 1
startIndex[j], endIndex[j] = cStart, cEnd
}
return
}

View File

@ -0,0 +1,88 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"reflect"
"testing"
)
const (
gb1 = 1024 * 1024 * 1024
gb5 = 5 * gb1
gb5p1 = gb5 + 1
gb10p1 = 2*gb5 + 1
gb10p2 = 2*gb5 + 2
)
func TestPartsRequired(t *testing.T) {
testCases := []struct {
size, ref int64
}{
{0, 0},
{1, 1},
{gb5, 1},
{2 * gb5, 2},
{gb10p1, 3},
{gb10p2, 3},
}
for i, testCase := range testCases {
res := partsRequired(testCase.size)
if res != testCase.ref {
t.Errorf("Test %d - output did not match with reference results", i+1)
}
}
}
func TestCalculateEvenSplits(t *testing.T) {
testCases := []struct {
// input size and source object
size int64
src SourceInfo
// output part-indexes
starts, ends []int64
}{
{0, SourceInfo{start: -1}, nil, nil},
{1, SourceInfo{start: -1}, []int64{0}, []int64{0}},
{1, SourceInfo{start: 0}, []int64{0}, []int64{0}},
{gb1, SourceInfo{start: -1}, []int64{0}, []int64{gb1 - 1}},
{gb5, SourceInfo{start: -1}, []int64{0}, []int64{gb5 - 1}},
// 2 part splits
{gb5p1, SourceInfo{start: -1}, []int64{0, gb5/2 + 1}, []int64{gb5 / 2, gb5}},
{gb5p1, SourceInfo{start: -1}, []int64{0, gb5/2 + 1}, []int64{gb5 / 2, gb5}},
// 3 part splits
{gb10p1, SourceInfo{start: -1},
[]int64{0, gb10p1/3 + 1, 2*gb10p1/3 + 1},
[]int64{gb10p1 / 3, 2 * gb10p1 / 3, gb10p1 - 1}},
{gb10p2, SourceInfo{start: -1},
[]int64{0, gb10p2 / 3, 2 * gb10p2 / 3},
[]int64{gb10p2/3 - 1, 2*gb10p2/3 - 1, gb10p2 - 1}},
}
for i, testCase := range testCases {
resStart, resEnd := calculateEvenSplits(testCase.size, testCase.src)
if !reflect.DeepEqual(testCase.starts, resStart) || !reflect.DeepEqual(testCase.ends, resEnd) {
t.Errorf("Test %d - output did not match with reference results", i+1)
}
}
}

View File

@ -679,12 +679,18 @@ func (c Client) getObject(bucketName, objectName string, reqHeaders RequestHeade
if contentType == "" { if contentType == "" {
contentType = "application/octet-stream" contentType = "application/octet-stream"
} }
var objectStat ObjectInfo
objectStat.ETag = md5sum objectStat := ObjectInfo{
objectStat.Key = objectName ETag: md5sum,
objectStat.Size = resp.ContentLength Key: objectName,
objectStat.LastModified = date Size: resp.ContentLength,
objectStat.ContentType = contentType LastModified: date,
ContentType: contentType,
// Extract only the relevant header keys describing the object.
// following function filters out a list of standard set of keys
// which are not part of object metadata.
Metadata: extractObjMetadata(resp.Header),
}
// do not close body here, caller will close // do not close body here, caller will close
return resp.Body, objectStat, nil return resp.Body, objectStat, nil

View File

@ -17,10 +17,8 @@
package minio package minio
import ( import (
"fmt"
"hash" "hash"
"io" "io"
"io/ioutil"
"math" "math"
"os" "os"
@ -78,55 +76,6 @@ func optimalPartInfo(objectSize int64) (totalPartsCount int, partSize int64, las
return totalPartsCount, partSize, lastPartSize, nil return totalPartsCount, partSize, lastPartSize, nil
} }
// hashCopyBuffer is identical to hashCopyN except that it doesn't take
// any size argument but takes a buffer argument and reader should be
// of io.ReaderAt interface.
//
// Stages reads from offsets into the buffer, if buffer is nil it is
// initialized to optimalBufferSize.
func hashCopyBuffer(hashAlgorithms map[string]hash.Hash, hashSums map[string][]byte, writer io.Writer, reader io.ReaderAt, buf []byte) (size int64, err error) {
hashWriter := writer
for _, v := range hashAlgorithms {
hashWriter = io.MultiWriter(hashWriter, v)
}
// Buffer is nil, initialize.
if buf == nil {
buf = make([]byte, optimalReadBufferSize)
}
// Offset to start reading from.
var readAtOffset int64
// Following block reads data at an offset from the input
// reader and copies data to into local temporary file.
for {
readAtSize, rerr := reader.ReadAt(buf, readAtOffset)
if rerr != nil {
if rerr != io.EOF {
return 0, rerr
}
}
writeSize, werr := hashWriter.Write(buf[:readAtSize])
if werr != nil {
return 0, werr
}
if readAtSize != writeSize {
return 0, fmt.Errorf("Read size was not completely written to writer. wanted %d, got %d - %s", readAtSize, writeSize, reportIssue)
}
readAtOffset += int64(writeSize)
size += int64(writeSize)
if rerr == io.EOF {
break
}
}
for k, v := range hashAlgorithms {
hashSums[k] = v.Sum(nil)
}
return size, err
}
// hashCopyN - Calculates chosen hashes up to partSize amount of bytes. // hashCopyN - Calculates chosen hashes up to partSize amount of bytes.
func hashCopyN(hashAlgorithms map[string]hash.Hash, hashSums map[string][]byte, writer io.Writer, reader io.Reader, partSize int64) (size int64, err error) { func hashCopyN(hashAlgorithms map[string]hash.Hash, hashSums map[string][]byte, writer io.Writer, reader io.Reader, partSize int64) (size int64, err error) {
hashWriter := writer hashWriter := writer
@ -167,27 +116,3 @@ func (c Client) newUploadID(bucketName, objectName string, metaData map[string][
} }
return initMultipartUploadResult.UploadID, nil return initMultipartUploadResult.UploadID, nil
} }
// computeHash - Calculates hashes for an input read Seeker.
func computeHash(hashAlgorithms map[string]hash.Hash, hashSums map[string][]byte, reader io.ReadSeeker) (size int64, err error) {
hashWriter := ioutil.Discard
for _, v := range hashAlgorithms {
hashWriter = io.MultiWriter(hashWriter, v)
}
// If no buffer is provided, no need to allocate just use io.Copy.
size, err = io.Copy(hashWriter, reader)
if err != nil {
return 0, err
}
// Seek back reader to the beginning location.
if _, err := reader.Seek(0, 0); err != nil {
return 0, err
}
for k, v := range hashAlgorithms {
hashSums[k] = v.Sum(nil)
}
return size, nil
}

View File

@ -16,57 +16,7 @@
package minio package minio
import ( // CopyObject - copy a source object into a new object
"net/http" func (c Client) CopyObject(dst DestinationInfo, src SourceInfo) error {
return c.ComposeObject(dst, []SourceInfo{src})
"github.com/minio/minio-go/pkg/s3utils"
)
// CopyObject - copy a source object into a new object with the provided name in the provided bucket
func (c Client) CopyObject(bucketName string, objectName string, objectSource string, cpCond CopyConditions) error {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return err
}
if objectSource == "" {
return ErrInvalidArgument("Object source cannot be empty.")
}
// customHeaders apply headers.
customHeaders := make(http.Header)
for _, cond := range cpCond.conditions {
customHeaders.Set(cond.key, cond.value)
}
// Set copy source.
customHeaders.Set("x-amz-copy-source", s3utils.EncodePath(objectSource))
// Execute PUT on objectName.
resp, err := c.executeMethod("PUT", requestMetadata{
bucketName: bucketName,
objectName: objectName,
customHeader: customHeaders,
})
defer closeResponse(resp)
if err != nil {
return err
}
if resp != nil {
if resp.StatusCode != http.StatusOK {
return httpRespToErrorResponse(resp, bucketName, objectName)
}
}
// Decode copy response on success.
cpObjRes := copyObjectResult{}
err = xmlDecoder(resp.Body, &cpObjRes)
if err != nil {
return err
}
// Return nil on success.
return nil
} }

View File

@ -0,0 +1,46 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"io"
"github.com/minio/minio-go/pkg/encrypt"
)
// PutEncryptedObject - Encrypt and store object.
func (c Client) PutEncryptedObject(bucketName, objectName string, reader io.Reader, encryptMaterials encrypt.Materials, metadata map[string][]string, progress io.Reader) (n int64, err error) {
if encryptMaterials == nil {
return 0, ErrInvalidArgument("Unable to recognize empty encryption properties")
}
if err := encryptMaterials.SetupEncryptMode(reader); err != nil {
return 0, err
}
if metadata == nil {
metadata = make(map[string][]string)
}
// Set the necessary encryption headers, for future decryption.
metadata[amzHeaderIV] = []string{encryptMaterials.GetIV()}
metadata[amzHeaderKey] = []string{encryptMaterials.GetKey()}
metadata[amzHeaderMatDesc] = []string{encryptMaterials.GetDesc()}
return c.putObjectMultipart(bucketName, objectName, encryptMaterials, -1, metadata, progress)
}

View File

@ -17,13 +17,9 @@
package minio package minio
import ( import (
"fmt"
"io"
"io/ioutil"
"mime" "mime"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"github.com/minio/minio-go/pkg/s3utils" "github.com/minio/minio-go/pkg/s3utils"
) )
@ -55,11 +51,6 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
// Save the file size. // Save the file size.
fileSize := fileStat.Size() fileSize := fileStat.Size()
// Check for largest object size allowed.
if fileSize > int64(maxMultipartPutObjectSize) {
return 0, ErrEntityTooLarge(fileSize, maxMultipartPutObjectSize, bucketName, objectName)
}
objMetadata := make(map[string][]string) objMetadata := make(map[string][]string)
// Set contentType based on filepath extension if not given or default // Set contentType based on filepath extension if not given or default
@ -71,195 +62,5 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
} }
objMetadata["Content-Type"] = []string{contentType} objMetadata["Content-Type"] = []string{contentType}
return c.putObjectCommon(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
// NOTE: Google Cloud Storage multipart Put is not compatible with Amazon S3 APIs.
if s3utils.IsGoogleEndpoint(c.endpointURL) {
// Do not compute MD5 for Google Cloud Storage.
return c.putObjectNoChecksum(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
// Small object upload is initiated for uploads for input data size smaller than 5MiB.
if fileSize < minPartSize && fileSize >= 0 {
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
// Upload all large objects as multipart.
n, err = c.putObjectMultipartFromFile(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
// fall back to single PutObject operation.
if errResp.Code == "NotImplemented" {
// If size of file is greater than '5GiB' fail.
if fileSize > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(fileSize, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
return n, err
}
return n, nil
}
// putObjectMultipartFromFile - Creates object from contents of *os.File
//
// NOTE: This function is meant to be used for readers with local
// file as in *os.File. This function effectively utilizes file
// system capabilities of reading from specific sections and not
// having to create temporary files.
func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileReader io.ReaderAt, fileSize int64, metaData map[string][]string, progress io.Reader) (int64, error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Initiate a new multipart upload.
uploadID, err := c.newUploadID(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(fileSize)
if err != nil {
return 0, err
}
// Create a channel to communicate a part was uploaded.
// Buffer this to 10000, the maximum number of parts allowed by S3.
uploadedPartsCh := make(chan uploadedPartRes, 10000)
// Create a channel to communicate which part to upload.
// Buffer this to 10000, the maximum number of parts allowed by S3.
uploadPartsCh := make(chan uploadPartReq, 10000)
// Just for readability.
lastPartNumber := totalPartsCount
// Initialize parts uploaded map.
partsInfo := make(map[int]ObjectPart)
// Send each part through the partUploadCh to be uploaded.
for p := 1; p <= totalPartsCount; p++ {
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)
// Use three 'workers' to upload parts in parallel.
for w := 1; w <= totalWorkers; w++ {
go func() {
// Deal with each part as it comes through the channel.
for uploadReq := range uploadPartsCh {
// Add hash algorithms that need to be calculated by computeHash()
// In case of a non-v4 signature or https connection, sha256 is not needed.
hashAlgos, hashSums := c.hashMaterials()
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if uploadReq.PartNum == lastPartNumber {
readOffset = (fileSize - lastPartSize)
missingPartSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(fileReader, readOffset, missingPartSize)
var prtSize int64
var err error
prtSize, err = computeHash(hashAlgos, hashSums, sectionReader)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Error: err,
}
// Exit the goroutine.
return
}
// Proceed to upload the part.
var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, sectionReader, uploadReq.PartNum,
hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Error: err,
}
// Exit the goroutine.
return
}
// Save successfully uploaded part metadata.
uploadReq.Part = &objPart
// Return through the channel the part size.
uploadedPartsCh <- uploadedPartRes{
Size: missingPartSize,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
}()
}
// Retrieve each uploaded part once it is done.
for u := 1; u <= totalPartsCount; u++ {
uploadRes := <-uploadedPartsCh
if uploadRes.Error != nil {
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
part := uploadRes.Part
if part == nil {
return totalUploadedSize, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the total uploaded size.
totalUploadedSize += uploadRes.Size
// Update the progress bar if there is one.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, uploadRes.Size); err != nil {
return totalUploadedSize, err
}
}
// Store the part to be completed.
complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Verify if we uploaded all data.
if totalUploadedSize != fileSize {
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, fileSize, bucketName, objectName)
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}
// Return final size.
return totalUploadedSize, nil
} }

View File

@ -24,7 +24,6 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -32,161 +31,60 @@ import (
"github.com/minio/minio-go/pkg/s3utils" "github.com/minio/minio-go/pkg/s3utils"
) )
// Comprehensive put object operation involving multipart uploads. func (c Client) putObjectMultipart(bucketName, objectName string, reader io.Reader, size int64,
// metadata map[string][]string, progress io.Reader) (n int64, err error) {
// Following code handles these types of readers. n, err = c.putObjectMultipartNoStream(bucketName, objectName, reader, size, metadata, progress)
// if err != nil {
// - *os.File errResp := ToErrorResponse(err)
// - *minio.Object // Verify if multipart functionality is not available, if not
// - Any reader which has a method 'ReadAt()' // fall back to single PutObject operation.
// if errResp.Code == "AccessDenied" && strings.Contains(errResp.Message, "Access Denied") {
func (c Client) putObjectMultipart(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) { // Verify if size of reader is greater than '5GiB'.
if size > 0 && size > minPartSize { if size > maxSinglePutObjectSize {
// Verify if reader is *os.File, then use file system functionalities. return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
if isFile(reader) { }
return c.putObjectMultipartFromFile(bucketName, objectName, reader.(*os.File), size, metaData, progress) // Fall back to uploading as single PutObject operation.
} return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
// Verify if reader is *minio.Object or io.ReaderAt.
// NOTE: Verification of object is kept for a specific purpose
// while it is going to be duck typed similar to io.ReaderAt.
// It is to indicate that *minio.Object implements io.ReaderAt.
// and such a functionality is used in the subsequent code
// path.
if isObject(reader) || isReadAt(reader) {
return c.putObjectMultipartFromReadAt(bucketName, objectName, reader.(io.ReaderAt), size, metaData, progress)
} }
} }
// For any other data size and reader type we do generic multipart return n, err
// approach by staging data in temporary files and uploading them.
return c.putObjectMultipartStream(bucketName, objectName, reader, size, metaData, progress)
} }
// putObjectMultipartStreamNoChecksum - upload a large object using func (c Client) putObjectMultipartNoStream(bucketName, objectName string, reader io.Reader, size int64,
// multipart upload and streaming signature for signing payload. metadata map[string][]string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectMultipartStreamNoChecksum(bucketName, objectName string,
reader io.Reader, size int64, metadata map[string][]string, progress io.Reader) (int64, error) {
// Input validation. // Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil { if err = s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err return 0, err
} }
if err := s3utils.CheckValidObjectName(objectName); err != nil { if err = s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err return 0, err
} }
// Initiates a new multipart request // Total data read and written to server. should be equal to
uploadID, err := c.newUploadID(bucketName, objectName, metadata) // 'size' at the end of the call.
if err != nil {
return 0, err
}
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(size)
if err != nil {
return 0, err
}
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64
// Initialize parts uploaded map.
partsInfo := make(map[int]ObjectPart)
// Part number always starts with '1'.
var partNumber int
for partNumber = 1; partNumber <= totalPartsCount; partNumber++ {
// Update progress reader appropriately to the latest offset
// as we read from the source.
hookReader := newHook(reader, progress)
// Proceed to upload the part.
if partNumber == totalPartsCount {
partSize = lastPartSize
}
var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID,
io.LimitReader(hookReader, partSize), partNumber, nil, nil, partSize)
// For unknown size, Read EOF we break away.
// We do not have to upload till totalPartsCount.
if err == io.EOF && size < 0 {
break
}
if err != nil {
return totalUploadedSize, err
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
// Save successfully uploaded size.
totalUploadedSize += partSize
}
// Verify if we uploaded all the data.
if size > 0 {
if totalUploadedSize != size {
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, size, bucketName, objectName)
}
}
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// Loop over total uploaded parts to save them in
// Parts array before completing the multipart request.
for i := 1; i < partNumber; i++ {
part, ok := partsInfo[i]
if !ok {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", i))
}
complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}
// Return final size.
return totalUploadedSize, nil
}
// putObjectStream uploads files bigger than 64MiB, and also supports
// special case where size is unknown i.e '-1'.
func (c Client) putObjectMultipartStream(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64 var totalUploadedSize int64
// Complete multipart upload. // Complete multipart upload.
var complMultipartUpload completeMultipartUpload var complMultipartUpload completeMultipartUpload
// Initiate a new multipart upload.
uploadID, err := c.newUploadID(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
// Calculate the optimal parts info for a given size. // Calculate the optimal parts info for a given size.
totalPartsCount, partSize, _, err := optimalPartInfo(size) totalPartsCount, partSize, _, err := optimalPartInfo(size)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// Initiate a new multipart upload.
uploadID, err := c.newUploadID(bucketName, objectName, metadata)
if err != nil {
return 0, err
}
defer func() {
if err != nil {
c.abortMultipartUpload(bucketName, objectName, uploadID)
}
}()
// Part number always starts with '1'. // Part number always starts with '1'.
partNumber := 1 partNumber := 1
@ -197,8 +95,9 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
partsInfo := make(map[int]ObjectPart) partsInfo := make(map[int]ObjectPart)
for partNumber <= totalPartsCount { for partNumber <= totalPartsCount {
// Choose hash algorithms to be calculated by hashCopyN, avoid sha256 // Choose hash algorithms to be calculated by hashCopyN,
// with non-v4 signature request or HTTPS connection // avoid sha256 with non-v4 signature request or
// HTTPS connection.
hashAlgos, hashSums := c.hashMaterials() hashAlgos, hashSums := c.hashMaterials()
// Calculates hash sums while copying partSize bytes into tmpBuffer. // Calculates hash sums while copying partSize bytes into tmpBuffer.
@ -214,7 +113,8 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// Proceed to upload the part. // Proceed to upload the part.
var objPart ObjectPart var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber, hashSums["md5"], hashSums["sha256"], prtSize) objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber,
hashSums["md5"], hashSums["sha256"], prtSize, metadata)
if err != nil { if err != nil {
// Reset the temporary buffer upon any error. // Reset the temporary buffer upon any error.
tmpBuffer.Reset() tmpBuffer.Reset()
@ -224,13 +124,6 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// Save successfully uploaded part metadata. // Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart partsInfo[partNumber] = objPart
// Update the progress reader for the skipped part.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, prtSize); err != nil {
return totalUploadedSize, err
}
}
// Reset the temporary buffer. // Reset the temporary buffer.
tmpBuffer.Reset() tmpBuffer.Reset()
@ -269,8 +162,7 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// Sort all completed parts. // Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts)) sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload) if _, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload); err != nil {
if err != nil {
return totalUploadedSize, err return totalUploadedSize, err
} }
@ -279,7 +171,7 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
} }
// initiateMultipartUpload - Initiates a multipart upload and returns an upload ID. // initiateMultipartUpload - Initiates a multipart upload and returns an upload ID.
func (c Client) initiateMultipartUpload(bucketName, objectName string, metaData map[string][]string) (initiateMultipartUploadResult, error) { func (c Client) initiateMultipartUpload(bucketName, objectName string, metadata map[string][]string) (initiateMultipartUploadResult, error) {
// Input validation. // Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil { if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return initiateMultipartUploadResult{}, err return initiateMultipartUploadResult{}, err
@ -294,14 +186,14 @@ func (c Client) initiateMultipartUpload(bucketName, objectName string, metaData
// Set ContentType header. // Set ContentType header.
customHeader := make(http.Header) customHeader := make(http.Header)
for k, v := range metaData { for k, v := range metadata {
if len(v) > 0 { if len(v) > 0 {
customHeader.Set(k, v[0]) customHeader.Set(k, v[0])
} }
} }
// Set a default content-type header if the latter is not provided // Set a default content-type header if the latter is not provided
if v, ok := metaData["Content-Type"]; !ok || len(v) == 0 { if v, ok := metadata["Content-Type"]; !ok || len(v) == 0 {
customHeader.Set("Content-Type", "application/octet-stream") customHeader.Set("Content-Type", "application/octet-stream")
} }
@ -332,8 +224,11 @@ func (c Client) initiateMultipartUpload(bucketName, objectName string, metaData
return initiateMultipartUploadResult, nil return initiateMultipartUploadResult, nil
} }
const serverEncryptionKeyPrefix = "x-amz-server-side-encryption"
// uploadPart - Uploads a part in a multipart upload. // uploadPart - Uploads a part in a multipart upload.
func (c Client) uploadPart(bucketName, objectName, uploadID string, reader io.Reader, partNumber int, md5Sum, sha256Sum []byte, size int64) (ObjectPart, error) { func (c Client) uploadPart(bucketName, objectName, uploadID string, reader io.Reader,
partNumber int, md5Sum, sha256Sum []byte, size int64, metadata map[string][]string) (ObjectPart, error) {
// Input validation. // Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil { if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return ObjectPart{}, err return ObjectPart{}, err
@ -361,10 +256,21 @@ func (c Client) uploadPart(bucketName, objectName, uploadID string, reader io.Re
// Set upload id. // Set upload id.
urlValues.Set("uploadId", uploadID) urlValues.Set("uploadId", uploadID)
// Set encryption headers, if any.
customHeader := make(http.Header)
for k, v := range metadata {
if len(v) > 0 {
if strings.HasPrefix(strings.ToLower(k), serverEncryptionKeyPrefix) {
customHeader.Set(k, v[0])
}
}
}
reqMetadata := requestMetadata{ reqMetadata := requestMetadata{
bucketName: bucketName, bucketName: bucketName,
objectName: objectName, objectName: objectName,
queryValues: urlValues, queryValues: urlValues,
customHeader: customHeader,
contentBody: reader, contentBody: reader,
contentLength: size, contentLength: size,
contentMD5Bytes: md5Sum, contentMD5Bytes: md5Sum,
@ -393,7 +299,8 @@ func (c Client) uploadPart(bucketName, objectName, uploadID string, reader io.Re
} }
// completeMultipartUpload - Completes a multipart upload by assembling previously uploaded parts. // completeMultipartUpload - Completes a multipart upload by assembling previously uploaded parts.
func (c Client) completeMultipartUpload(bucketName, objectName, uploadID string, complete completeMultipartUpload) (completeMultipartUploadResult, error) { func (c Client) completeMultipartUpload(bucketName, objectName, uploadID string,
complete completeMultipartUpload) (completeMultipartUploadResult, error) {
// Input validation. // Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil { if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return completeMultipartUploadResult{}, err return completeMultipartUploadResult{}, err

View File

@ -1,191 +0,0 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"io"
"strings"
"github.com/minio/minio-go/pkg/credentials"
"github.com/minio/minio-go/pkg/encrypt"
"github.com/minio/minio-go/pkg/s3utils"
)
// PutObjectWithProgress - with progress.
func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.Reader, contentType string, progress io.Reader) (n int64, err error) {
metaData := make(map[string][]string)
metaData["Content-Type"] = []string{contentType}
return c.PutObjectWithMetadata(bucketName, objectName, reader, metaData, progress)
}
// PutEncryptedObject - Encrypt and store object.
func (c Client) PutEncryptedObject(bucketName, objectName string, reader io.Reader, encryptMaterials encrypt.Materials, metaData map[string][]string, progress io.Reader) (n int64, err error) {
if encryptMaterials == nil {
return 0, ErrInvalidArgument("Unable to recognize empty encryption properties")
}
if err := encryptMaterials.SetupEncryptMode(reader); err != nil {
return 0, err
}
if metaData == nil {
metaData = make(map[string][]string)
}
// Set the necessary encryption headers, for future decryption.
metaData[amzHeaderIV] = []string{encryptMaterials.GetIV()}
metaData[amzHeaderKey] = []string{encryptMaterials.GetKey()}
metaData[amzHeaderMatDesc] = []string{encryptMaterials.GetDesc()}
return c.PutObjectWithMetadata(bucketName, objectName, encryptMaterials, metaData, progress)
}
// PutObjectWithMetadata - with metadata.
func (c Client) PutObjectWithMetadata(bucketName, objectName string, reader io.Reader, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
if reader == nil {
return 0, ErrInvalidArgument("Input reader is invalid, cannot be nil.")
}
// Size of the object.
var size int64
// Get reader size.
size, err = getReaderSize(reader)
if err != nil {
return 0, err
}
// Check for largest object size allowed.
if size > int64(maxMultipartPutObjectSize) {
return 0, ErrEntityTooLarge(size, maxMultipartPutObjectSize, bucketName, objectName)
}
// NOTE: Google Cloud Storage does not implement Amazon S3 Compatible multipart PUT.
if s3utils.IsGoogleEndpoint(c.endpointURL) {
// Do not compute MD5 for Google Cloud Storage.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metaData, progress)
}
// putSmall object.
if size < minPartSize && size >= 0 {
return c.putObjectSingle(bucketName, objectName, reader, size, metaData, progress)
}
// For all sizes greater than 5MiB do multipart.
n, err = c.putObjectMultipart(bucketName, objectName, reader, size, metaData, progress)
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
// fall back to single PutObject operation.
if errResp.Code == "AccessDenied" && strings.Contains(errResp.Message, "Access Denied") {
// Verify if size of reader is greater than '5GiB'.
if size > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectSingle(bucketName, objectName, reader, size, metaData, progress)
}
return n, err
}
return n, nil
}
// PutObjectStreaming using AWS streaming signature V4
func (c Client) PutObjectStreaming(bucketName, objectName string, reader io.Reader) (n int64, err error) {
return c.PutObjectStreamingWithProgress(bucketName, objectName, reader, nil, nil)
}
// PutObjectStreamingWithMetadata using AWS streaming signature V4
func (c Client) PutObjectStreamingWithMetadata(bucketName, objectName string, reader io.Reader, metadata map[string][]string) (n int64, err error) {
return c.PutObjectStreamingWithProgress(bucketName, objectName, reader, metadata, nil)
}
// PutObjectStreamingWithProgress using AWS streaming signature V4
func (c Client) PutObjectStreamingWithProgress(bucketName, objectName string, reader io.Reader, metadata map[string][]string, progress io.Reader) (n int64, err error) {
// NOTE: Streaming signature is not supported by GCS.
if s3utils.IsGoogleEndpoint(c.endpointURL) {
return 0, ErrorResponse{
Code: "NotImplemented",
Message: "AWS streaming signature v4 is not supported with Google Cloud Storage",
Key: objectName,
BucketName: bucketName,
}
}
if c.overrideSignerType.IsV2() {
return 0, ErrorResponse{
Code: "NotImplemented",
Message: "AWS streaming signature v4 is not supported with minio client initialized for AWS signature v2",
Key: objectName,
BucketName: bucketName,
}
}
// Size of the object.
var size int64
// Get reader size.
size, err = getReaderSize(reader)
if err != nil {
return 0, err
}
// Check for largest object size allowed.
if size > int64(maxMultipartPutObjectSize) {
return 0, ErrEntityTooLarge(size, maxMultipartPutObjectSize, bucketName, objectName)
}
// If size cannot be found on a stream, it is not possible
// to upload using streaming signature, fall back to multipart.
if size < 0 {
return c.putObjectMultipartStream(bucketName, objectName, reader, size, metadata, progress)
}
// Set streaming signature.
c.overrideSignerType = credentials.SignatureV4Streaming
if size < minPartSize && size >= 0 {
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
}
// For all sizes greater than 64MiB do multipart.
n, err = c.putObjectMultipartStreamNoChecksum(bucketName, objectName, reader, size, metadata, progress)
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
// fall back to single PutObject operation.
if errResp.Code == "AccessDenied" && strings.Contains(errResp.Message, "Access Denied") {
// Verify if size of reader is greater than '5GiB'.
if size > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
}
return n, err
}
return n, nil
}

View File

@ -1,219 +0,0 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015, 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"sort"
"github.com/minio/minio-go/pkg/s3utils"
)
// uploadedPartRes - the response received from a part upload.
type uploadedPartRes struct {
Error error // Any error encountered while uploading the part.
PartNum int // Number of the part uploaded.
Size int64 // Size of the part uploaded.
Part *ObjectPart
}
type uploadPartReq struct {
PartNum int // Number of the part uploaded.
Part *ObjectPart // Size of the part uploaded.
}
// putObjectMultipartFromReadAt - Uploads files bigger than 5MiB. Supports reader
// of type which implements io.ReaderAt interface (ReadAt method).
//
// NOTE: This function is meant to be used for all readers which
// implement io.ReaderAt which allows us for resuming multipart
// uploads but reading at an offset, which would avoid re-read the
// data which was already uploaded. Internally this function uses
// temporary files for staging all the data, these temporary files are
// cleaned automatically when the caller i.e http client closes the
// stream after uploading all the contents successfully.
func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, reader io.ReaderAt, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Initiate a new multipart upload.
uploadID, err := c.newUploadID(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(size)
if err != nil {
return 0, err
}
// Declare a channel that sends the next part number to be uploaded.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadPartsCh := make(chan uploadPartReq, 10000)
// Declare a channel that sends back the response of a part upload.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadedPartsCh := make(chan uploadedPartRes, 10000)
// Used for readability, lastPartNumber is always totalPartsCount.
lastPartNumber := totalPartsCount
// Initialize parts uploaded map.
partsInfo := make(map[int]ObjectPart)
// Send each part number to the channel to be processed.
for p := 1; p <= totalPartsCount; p++ {
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)
// Receive each part number from the channel allowing three parallel uploads.
for w := 1; w <= totalWorkers; w++ {
go func() {
// Read defaults to reading at 5MiB buffer.
readAtBuffer := make([]byte, optimalReadBufferSize)
// Each worker will draw from the part channel and upload in parallel.
for uploadReq := range uploadPartsCh {
// Declare a new tmpBuffer.
tmpBuffer := new(bytes.Buffer)
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if uploadReq.PartNum == lastPartNumber {
readOffset = (size - lastPartSize)
missingPartSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(reader, readOffset, missingPartSize)
// Choose the needed hash algorithms to be calculated by hashCopyBuffer.
// Sha256 is avoided in non-v4 signature requests or HTTPS connections
hashAlgos, hashSums := c.hashMaterials()
var prtSize int64
var err error
prtSize, err = hashCopyBuffer(hashAlgos, hashSums, tmpBuffer, sectionReader, readAtBuffer)
if err != nil {
// Send the error back through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Error: err,
}
// Exit the goroutine.
return
}
// Proceed to upload the part.
var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, tmpBuffer,
uploadReq.PartNum, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Error: err,
}
// Exit the goroutine.
return
}
// Save successfully uploaded part metadata.
uploadReq.Part = &objPart
// Send successful part info through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: missingPartSize,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
}()
}
// Gather the responses as they occur and update any
// progress bar.
for u := 1; u <= totalPartsCount; u++ {
uploadRes := <-uploadedPartsCh
if uploadRes.Error != nil {
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
// part, ok := partsInfo[uploadRes.PartNum]
part := uploadRes.Part
if part == nil {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the totalUploadedSize.
totalUploadedSize += uploadRes.Size
// Update the progress bar if there is one.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, uploadRes.Size); err != nil {
return totalUploadedSize, err
}
}
// Store the parts to be completed in order.
complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Verify if we uploaded all the data.
if totalUploadedSize != size {
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, size, bucketName, objectName)
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}
// Return final size.
return totalUploadedSize, nil
}

View File

@ -0,0 +1,436 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"github.com/minio/minio-go/pkg/s3utils"
)
// PutObjectStreaming using AWS streaming signature V4
func (c Client) PutObjectStreaming(bucketName, objectName string, reader io.Reader) (n int64, err error) {
return c.PutObjectWithProgress(bucketName, objectName, reader, nil, nil)
}
// putObjectMultipartStream - upload a large object using
// multipart upload and streaming signature for signing payload.
// Comprehensive put object operation involving multipart uploads.
//
// Following code handles these types of readers.
//
// - *os.File
// - *minio.Object
// - Any reader which has a method 'ReadAt()'
//
func (c Client) putObjectMultipartStream(bucketName, objectName string,
reader io.Reader, size int64, metadata map[string][]string, progress io.Reader) (n int64, err error) {
// Verify if reader is *minio.Object, *os.File or io.ReaderAt.
// NOTE: Verification of object is kept for a specific purpose
// while it is going to be duck typed similar to io.ReaderAt.
// It is to indicate that *minio.Object implements io.ReaderAt.
// and such a functionality is used in the subsequent code path.
if isFile(reader) || !isObject(reader) && isReadAt(reader) {
n, err = c.putObjectMultipartStreamFromReadAt(bucketName, objectName, reader.(io.ReaderAt), size, metadata, progress)
} else {
n, err = c.putObjectMultipartStreamNoChecksum(bucketName, objectName, reader, size, metadata, progress)
}
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
// fall back to single PutObject operation.
if errResp.Code == "AccessDenied" && strings.Contains(errResp.Message, "Access Denied") {
// Verify if size of reader is greater than '5GiB'.
if size > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
}
}
return n, err
}
// uploadedPartRes - the response received from a part upload.
type uploadedPartRes struct {
Error error // Any error encountered while uploading the part.
PartNum int // Number of the part uploaded.
Size int64 // Size of the part uploaded.
Part *ObjectPart
}
type uploadPartReq struct {
PartNum int // Number of the part uploaded.
Part *ObjectPart // Size of the part uploaded.
}
// putObjectMultipartFromReadAt - Uploads files bigger than 64MiB.
// Supports all readers which implements io.ReaderAt interface
// (ReadAt method).
//
// NOTE: This function is meant to be used for all readers which
// implement io.ReaderAt which allows us for resuming multipart
// uploads but reading at an offset, which would avoid re-read the
// data which was already uploaded. Internally this function uses
// temporary files for staging all the data, these temporary files are
// cleaned automatically when the caller i.e http client closes the
// stream after uploading all the contents successfully.
func (c Client) putObjectMultipartStreamFromReadAt(bucketName, objectName string,
reader io.ReaderAt, size int64, metadata map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err = s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err = s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(size)
if err != nil {
return 0, err
}
// Initiate a new multipart upload.
uploadID, err := c.newUploadID(bucketName, objectName, metadata)
if err != nil {
return 0, err
}
// Aborts the multipart upload in progress, if the
// function returns any error, since we do not resume
// we should purge the parts which have been uploaded
// to relinquish storage space.
defer func() {
if err != nil {
c.abortMultipartUpload(bucketName, objectName, uploadID)
}
}()
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// Declare a channel that sends the next part number to be uploaded.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadPartsCh := make(chan uploadPartReq, 10000)
// Declare a channel that sends back the response of a part upload.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadedPartsCh := make(chan uploadedPartRes, 10000)
// Used for readability, lastPartNumber is always totalPartsCount.
lastPartNumber := totalPartsCount
// Send each part number to the channel to be processed.
for p := 1; p <= totalPartsCount; p++ {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
close(uploadPartsCh)
// Receive each part number from the channel allowing three parallel uploads.
for w := 1; w <= totalWorkers; w++ {
go func() {
// Each worker will draw from the part channel and upload in parallel.
for uploadReq := range uploadPartsCh {
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(uploadReq.PartNum-1) * partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if uploadReq.PartNum == lastPartNumber {
readOffset = (size - lastPartSize)
partSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := newHook(io.NewSectionReader(reader, readOffset, partSize), progress)
// Proceed to upload the part.
var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID,
sectionReader, uploadReq.PartNum,
nil, nil, partSize, metadata)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Error: err,
}
// Exit the goroutine.
return
}
// Save successfully uploaded part metadata.
uploadReq.Part = &objPart
// Send successful part info through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: objPart.Size,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
}()
}
// Gather the responses as they occur and update any
// progress bar.
for u := 1; u <= totalPartsCount; u++ {
uploadRes := <-uploadedPartsCh
if uploadRes.Error != nil {
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
// part, ok := partsInfo[uploadRes.PartNum]
part := uploadRes.Part
if part == nil {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the totalUploadedSize.
totalUploadedSize += uploadRes.Size
// Store the parts to be completed in order.
complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Verify if we uploaded all the data.
if totalUploadedSize != size {
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, size, bucketName, objectName)
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}
// Return final size.
return totalUploadedSize, nil
}
func (c Client) putObjectMultipartStreamNoChecksum(bucketName, objectName string,
reader io.Reader, size int64, metadata map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err = s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err = s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(size)
if err != nil {
return 0, err
}
// Initiates a new multipart request
uploadID, err := c.newUploadID(bucketName, objectName, metadata)
if err != nil {
return 0, err
}
// Aborts the multipart upload if the function returns
// any error, since we do not resume we should purge
// the parts which have been uploaded to relinquish
// storage space.
defer func() {
if err != nil {
c.abortMultipartUpload(bucketName, objectName, uploadID)
}
}()
// Total data read and written to server. should be equal to 'size' at the end of the call.
var totalUploadedSize int64
// Initialize parts uploaded map.
partsInfo := make(map[int]ObjectPart)
// Part number always starts with '1'.
var partNumber int
for partNumber = 1; partNumber <= totalPartsCount; partNumber++ {
// Update progress reader appropriately to the latest offset
// as we read from the source.
hookReader := newHook(reader, progress)
// Proceed to upload the part.
if partNumber == totalPartsCount {
partSize = lastPartSize
}
var objPart ObjectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID,
io.LimitReader(hookReader, partSize),
partNumber, nil, nil, partSize, metadata)
if err != nil {
return totalUploadedSize, err
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
// Save successfully uploaded size.
totalUploadedSize += partSize
}
// Verify if we uploaded all the data.
if size > 0 {
if totalUploadedSize != size {
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, size, bucketName, objectName)
}
}
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// Loop over total uploaded parts to save them in
// Parts array before completing the multipart request.
for i := 1; i < partNumber; i++ {
part, ok := partsInfo[i]
if !ok {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", i))
}
complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}
// Return final size.
return totalUploadedSize, nil
}
// putObjectNoChecksum special function used Google Cloud Storage. This special function
// is used for Google Cloud Storage since Google's multipart API is not S3 compatible.
func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
// Size -1 is only supported on Google Cloud Storage, we error
// out in all other situations.
if size < 0 && !s3utils.IsGoogleEndpoint(c.endpointURL) {
return 0, ErrEntityTooSmall(size, bucketName, objectName)
}
if size > 0 {
if isReadAt(reader) && !isObject(reader) {
reader = io.NewSectionReader(reader.(io.ReaderAt), 0, size)
}
}
// Update progress reader appropriately to the latest offset as we
// read from the source.
readSeeker := newHook(reader, progress)
// This function does not calculate sha256 and md5sum for payload.
// Execute put object.
st, err := c.putObjectDo(bucketName, objectName, readSeeker, nil, nil, size, metaData)
if err != nil {
return 0, err
}
if st.Size != size {
return 0, ErrUnexpectedEOF(st.Size, size, bucketName, objectName)
}
return size, nil
}
// putObjectDo - executes the put object http operation.
// NOTE: You must have WRITE permissions on a bucket to add an object to it.
func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5Sum []byte, sha256Sum []byte, size int64, metaData map[string][]string) (ObjectInfo, error) {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return ObjectInfo{}, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return ObjectInfo{}, err
}
// Set headers.
customHeader := make(http.Header)
// Set metadata to headers
for k, v := range metaData {
if len(v) > 0 {
customHeader.Set(k, v[0])
}
}
// If Content-Type is not provided, set the default application/octet-stream one
if v, ok := metaData["Content-Type"]; !ok || len(v) == 0 {
customHeader.Set("Content-Type", "application/octet-stream")
}
// Populate request metadata.
reqMetadata := requestMetadata{
bucketName: bucketName,
objectName: objectName,
customHeader: customHeader,
contentBody: reader,
contentLength: size,
contentMD5Bytes: md5Sum,
contentSHA256Bytes: sha256Sum,
}
// Execute PUT an objectName.
resp, err := c.executeMethod("PUT", reqMetadata)
defer closeResponse(resp)
if err != nil {
return ObjectInfo{}, err
}
if resp != nil {
if resp.StatusCode != http.StatusOK {
return ObjectInfo{}, httpRespToErrorResponse(resp, bucketName, objectName)
}
}
var objInfo ObjectInfo
// Trim off the odd double quotes from ETag in the beginning and end.
objInfo.ETag = strings.TrimPrefix(resp.Header.Get("ETag"), "\"")
objInfo.ETag = strings.TrimSuffix(objInfo.ETag, "\"")
// A success here means data was written to server successfully.
objInfo.Size = size
// Return here.
return objInfo, nil
}

View File

@ -18,13 +18,12 @@ package minio
import ( import (
"io" "io"
"io/ioutil"
"net/http"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"github.com/minio/minio-go/pkg/credentials"
"github.com/minio/minio-go/pkg/s3utils" "github.com/minio/minio-go/pkg/s3utils"
) )
@ -143,164 +142,79 @@ func (a completedParts) Less(i, j int) bool { return a[i].PartNumber < a[j].Part
// //
// You must have WRITE permissions on a bucket to create an object. // You must have WRITE permissions on a bucket to create an object.
// //
// - For size smaller than 64MiB PutObject automatically does a single atomic Put operation. // - For size smaller than 64MiB PutObject automatically does a
// - For size larger than 64MiB PutObject automatically does a multipart Put operation. // single atomic Put operation.
// - For size input as -1 PutObject does a multipart Put operation until input stream reaches EOF. // - For size larger than 64MiB PutObject automatically does a
// Maximum object size that can be uploaded through this operation will be 5TiB. // multipart Put operation.
// // - For size input as -1 PutObject does a multipart Put operation
// NOTE: Google Cloud Storage does not implement Amazon S3 Compatible multipart PUT. // until input stream reaches EOF. Maximum object size that can
// So we fall back to single PUT operation with the maximum limit of 5GiB. // be uploaded through this operation will be 5TiB.
func (c Client) PutObject(bucketName, objectName string, reader io.Reader, contentType string) (n int64, err error) { func (c Client) PutObject(bucketName, objectName string, reader io.Reader, contentType string) (n int64, err error) {
return c.PutObjectWithProgress(bucketName, objectName, reader, contentType, nil) return c.PutObjectWithMetadata(bucketName, objectName, reader, map[string][]string{
"Content-Type": []string{contentType},
}, nil)
} }
// putObjectNoChecksum special function used Google Cloud Storage. This special function // PutObjectWithSize - is a helper PutObject similar in behavior to PutObject()
// is used for Google Cloud Storage since Google's multipart API is not S3 compatible. // but takes the size argument explicitly, this function avoids doing reflection
func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) { // internally to figure out the size of input stream. Also if the input size is
// Input validation. // lesser than 0 this function returns an error.
if err := s3utils.CheckValidBucketName(bucketName); err != nil { func (c Client) PutObjectWithSize(bucketName, objectName string, reader io.Reader, readerSize int64, metadata map[string][]string, progress io.Reader) (n int64, err error) {
return 0, err return c.putObjectCommon(bucketName, objectName, reader, readerSize, metadata, progress)
} }
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
if size > 0 {
readerAt, ok := reader.(io.ReaderAt)
if ok {
reader = io.NewSectionReader(readerAt, 0, size)
}
}
// Update progress reader appropriately to the latest offset as we // PutObjectWithMetadata using AWS streaming signature V4
// read from the source. func (c Client) PutObjectWithMetadata(bucketName, objectName string, reader io.Reader, metadata map[string][]string, progress io.Reader) (n int64, err error) {
readSeeker := newHook(reader, progress) return c.PutObjectWithProgress(bucketName, objectName, reader, metadata, progress)
}
// This function does not calculate sha256 and md5sum for payload. // PutObjectWithProgress using AWS streaming signature V4
// Execute put object. func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.Reader, metadata map[string][]string, progress io.Reader) (n int64, err error) {
st, err := c.putObjectDo(bucketName, objectName, readSeeker, nil, nil, size, metaData) // Size of the object.
var size int64
// Get reader size.
size, err = getReaderSize(reader)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if st.Size != size { return c.putObjectCommon(bucketName, objectName, reader, size, metadata, progress)
return 0, ErrUnexpectedEOF(st.Size, size, bucketName, objectName)
}
return size, nil
} }
// putObjectSingle is a special function for uploading single put object request. func (c Client) putObjectCommon(bucketName, objectName string, reader io.Reader, size int64, metadata map[string][]string, progress io.Reader) (n int64, err error) {
// This special function is used as a fallback when multipart upload fails. // Check for largest object size allowed.
func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) { if size > int64(maxMultipartPutObjectSize) {
// Input validation. return 0, ErrEntityTooLarge(size, maxMultipartPutObjectSize, bucketName, objectName)
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return 0, err
}
if err := s3utils.CheckValidObjectName(objectName); err != nil {
return 0, err
}
if size > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// If size is a stream, upload up to 5GiB.
if size <= -1 {
size = maxSinglePutObjectSize
} }
// Add the appropriate hash algorithms that need to be calculated by hashCopyN // NOTE: Streaming signature is not supported by GCS.
// In case of non-v4 signature request or HTTPS connection, sha256 is not needed. if s3utils.IsGoogleEndpoint(c.endpointURL) {
hashAlgos, hashSums := c.hashMaterials() // Do not compute MD5 for Google Cloud Storage.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
// Initialize a new temporary file.
tmpFile, err := newTempFile("single$-putobject-single")
if err != nil {
return 0, err
}
defer tmpFile.Close()
size, err = hashCopyN(hashAlgos, hashSums, tmpFile, reader, size)
// Return error if its not io.EOF.
if err != nil && err != io.EOF {
return 0, err
} }
// Seek back to beginning of the temporary file. if c.overrideSignerType.IsV2() {
if _, err = tmpFile.Seek(0, 0); err != nil { if size >= 0 && size < minPartSize {
return 0, err return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
}
reader = tmpFile
// Execute put object.
st, err := c.putObjectDo(bucketName, objectName, reader, hashSums["md5"], hashSums["sha256"], size, metaData)
if err != nil {
return 0, err
}
if st.Size != size {
return 0, ErrUnexpectedEOF(st.Size, size, bucketName, objectName)
}
// Progress the reader to the size if putObjectDo is successful.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, size); err != nil {
return size, err
} }
return c.putObjectMultipart(bucketName, objectName, reader, size, metadata, progress)
} }
return size, nil
} // If size cannot be found on a stream, it is not possible
// to upload using streaming signature, fall back to multipart.
// putObjectDo - executes the put object http operation. if size < 0 {
// NOTE: You must have WRITE permissions on a bucket to add an object to it. // Set regular signature calculation.
func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5Sum []byte, sha256Sum []byte, size int64, metaData map[string][]string) (ObjectInfo, error) { c.overrideSignerType = credentials.SignatureV4
// Input validation. return c.putObjectMultipart(bucketName, objectName, reader, size, metadata, progress)
if err := s3utils.CheckValidBucketName(bucketName); err != nil { }
return ObjectInfo{}, err
} // Set streaming signature.
if err := s3utils.CheckValidObjectName(objectName); err != nil { c.overrideSignerType = credentials.SignatureV4Streaming
return ObjectInfo{}, err
} if size < minPartSize {
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metadata, progress)
// Set headers. }
customHeader := make(http.Header)
// For all sizes greater than 64MiB do multipart.
// Set metadata to headers return c.putObjectMultipartStream(bucketName, objectName, reader, size, metadata, progress)
for k, v := range metaData {
if len(v) > 0 {
customHeader.Set(k, v[0])
}
}
// If Content-Type is not provided, set the default application/octet-stream one
if v, ok := metaData["Content-Type"]; !ok || len(v) == 0 {
customHeader.Set("Content-Type", "application/octet-stream")
}
// Populate request metadata.
reqMetadata := requestMetadata{
bucketName: bucketName,
objectName: objectName,
customHeader: customHeader,
contentBody: reader,
contentLength: size,
contentMD5Bytes: md5Sum,
contentSHA256Bytes: sha256Sum,
}
// Execute PUT an objectName.
resp, err := c.executeMethod("PUT", reqMetadata)
defer closeResponse(resp)
if err != nil {
return ObjectInfo{}, err
}
if resp != nil {
if resp.StatusCode != http.StatusOK {
return ObjectInfo{}, httpRespToErrorResponse(resp, bucketName, objectName)
}
}
var objInfo ObjectInfo
// Trim off the odd double quotes from ETag in the beginning and end.
objInfo.ETag = strings.TrimPrefix(resp.Header.Get("ETag"), "\"")
objInfo.ETag = strings.TrimSuffix(objInfo.ETag, "\"")
// A success here means data was written to server successfully.
objInfo.Size = size
// Return here.
return objInfo, nil
} }

View File

@ -167,11 +167,6 @@ func (c Client) statObject(bucketName, objectName string, reqHeaders RequestHead
contentType = "application/octet-stream" contentType = "application/octet-stream"
} }
// Extract only the relevant header keys describing the object.
// following function filters out a list of standard set of keys
// which are not part of object metadata.
metadata := extractObjMetadata(resp.Header)
// Save object metadata info. // Save object metadata info.
return ObjectInfo{ return ObjectInfo{
ETag: md5sum, ETag: md5sum,
@ -179,6 +174,9 @@ func (c Client) statObject(bucketName, objectName string, reqHeaders RequestHead
Size: size, Size: size,
LastModified: date, LastModified: date,
ContentType: contentType, ContentType: contentType,
Metadata: metadata, // Extract only the relevant header keys describing the object.
// following function filters out a list of standard set of keys
// which are not part of object metadata.
Metadata: extractObjMetadata(resp.Header),
}, nil }, nil
} }

View File

@ -87,7 +87,7 @@ type Client struct {
// Global constants. // Global constants.
const ( const (
libraryName = "minio-go" libraryName = "minio-go"
libraryVersion = "2.1.0" libraryVersion = "3.0.0"
) )
// User Agent should always following the below style. // User Agent should always following the below style.
@ -211,7 +211,7 @@ func privateNew(endpoint string, creds *credentials.Credentials, secure bool, re
// Instantiate http client and bucket location cache. // Instantiate http client and bucket location cache.
clnt.httpClient = &http.Client{ clnt.httpClient = &http.Client{
Transport: http.DefaultTransport, Transport: defaultMinioTransport,
CheckRedirect: redirectHeaders, CheckRedirect: redirectHeaders,
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -182,27 +182,6 @@ func TestValidBucketLocation(t *testing.T) {
} }
} }
// Tests temp file.
func TestTempFile(t *testing.T) {
tmpFile, err := newTempFile("testing")
if err != nil {
t.Fatal("Error:", err)
}
fileName := tmpFile.Name()
// Closing temporary file purges the file.
err = tmpFile.Close()
if err != nil {
t.Fatal("Error:", err)
}
st, err := os.Stat(fileName)
if err != nil && !os.IsNotExist(err) {
t.Fatal("Error:", err)
}
if err == nil && st != nil {
t.Fatal("Error: file should be deleted and should not exist.")
}
}
// Tests error response structure. // Tests error response structure.
func TestErrorResponse(t *testing.T) { func TestErrorResponse(t *testing.T) {
var err error var err error

View File

@ -213,20 +213,24 @@ func (c Client) getBucketLocationRequest(bucketName string) (*http.Request, erro
signerType = credentials.SignatureAnonymous signerType = credentials.SignatureAnonymous
} }
// Set sha256 sum for signature calculation only with signature version '4'. if signerType.IsAnonymous() {
switch { return req, nil
case signerType.IsV4():
var contentSha256 string
if c.secure {
contentSha256 = unsignedPayload
} else {
contentSha256 = hex.EncodeToString(sum256([]byte{}))
}
req.Header.Set("X-Amz-Content-Sha256", contentSha256)
req = s3signer.SignV4(*req, accessKeyID, secretAccessKey, sessionToken, "us-east-1")
case signerType.IsV2():
req = s3signer.SignV2(*req, accessKeyID, secretAccessKey)
} }
if signerType.IsV2() {
req = s3signer.SignV2(*req, accessKeyID, secretAccessKey)
return req, nil
}
// Set sha256 sum for signature calculation only with signature version '4'.
var contentSha256 string
if c.secure {
contentSha256 = unsignedPayload
} else {
contentSha256 = hex.EncodeToString(sum256([]byte{}))
}
req.Header.Set("X-Amz-Content-Sha256", contentSha256)
req = s3signer.SignV4(*req, accessKeyID, secretAccessKey, sessionToken, "us-east-1")
return req, nil return req, nil
} }

View File

@ -18,10 +18,18 @@ package minio
/// Multipart upload defaults. /// Multipart upload defaults.
// miniPartSize - minimum part size 64MiB per object after which // absMinPartSize - absolute minimum part size (5 MiB) below which
// a part in a multipart upload may not be uploaded.
const absMinPartSize = 1024 * 1024 * 5
// minPartSize - minimum part size 64MiB per object after which
// putObject behaves internally as multipart. // putObject behaves internally as multipart.
const minPartSize = 1024 * 1024 * 64 const minPartSize = 1024 * 1024 * 64
// copyPartSize - default (and maximum) part size to copy in a
// copy-object request (5GiB)
const copyPartSize = 1024 * 1024 * 1024 * 5
// maxPartsCount - maximum number of parts for a single multipart session. // maxPartsCount - maximum number of parts for a single multipart session.
const maxPartsCount = 10000 const maxPartsCount = 10000
@ -37,10 +45,6 @@ const maxSinglePutObjectSize = 1024 * 1024 * 1024 * 5
// Multipart operation. // Multipart operation.
const maxMultipartPutObjectSize = 1024 * 1024 * 1024 * 1024 * 5 const maxMultipartPutObjectSize = 1024 * 1024 * 1024 * 1024 * 5
// optimalReadBufferSize - optimal buffer 5MiB used for reading
// through Read operation.
const optimalReadBufferSize = 1024 * 1024 * 5
// unsignedPayload - value to be set to X-Amz-Content-Sha256 header when // unsignedPayload - value to be set to X-Amz-Content-Sha256 header when
// we don't want to sign the request payload // we don't want to sign the request payload
const unsignedPayload = "UNSIGNED-PAYLOAD" const unsignedPayload = "UNSIGNED-PAYLOAD"

View File

@ -1,99 +0,0 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"net/http"
"time"
)
// copyCondition explanation:
// http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectCOPY.html
//
// Example:
//
// copyCondition {
// key: "x-amz-copy-if-modified-since",
// value: "Tue, 15 Nov 1994 12:45:26 GMT",
// }
//
type copyCondition struct {
key string
value string
}
// CopyConditions - copy conditions.
type CopyConditions struct {
conditions []copyCondition
}
// NewCopyConditions - Instantiate new list of conditions. This
// function is left behind for backward compatibility. The idiomatic
// way to set an empty set of copy conditions is,
// ``copyConditions := CopyConditions{}``.
//
func NewCopyConditions() CopyConditions {
return CopyConditions{}
}
// SetMatchETag - set match etag.
func (c *CopyConditions) SetMatchETag(etag string) error {
if etag == "" {
return ErrInvalidArgument("ETag cannot be empty.")
}
c.conditions = append(c.conditions, copyCondition{
key: "x-amz-copy-source-if-match",
value: etag,
})
return nil
}
// SetMatchETagExcept - set match etag except.
func (c *CopyConditions) SetMatchETagExcept(etag string) error {
if etag == "" {
return ErrInvalidArgument("ETag cannot be empty.")
}
c.conditions = append(c.conditions, copyCondition{
key: "x-amz-copy-source-if-none-match",
value: etag,
})
return nil
}
// SetUnmodified - set unmodified time since.
func (c *CopyConditions) SetUnmodified(modTime time.Time) error {
if modTime.IsZero() {
return ErrInvalidArgument("Modified since cannot be empty.")
}
c.conditions = append(c.conditions, copyCondition{
key: "x-amz-copy-source-if-unmodified-since",
value: modTime.Format(http.TimeFormat),
})
return nil
}
// SetModified - set modified time since.
func (c *CopyConditions) SetModified(modTime time.Time) error {
if modTime.IsZero() {
return ErrInvalidArgument("Modified since cannot be empty.")
}
c.conditions = append(c.conditions, copyCondition{
key: "x-amz-copy-source-if-modified-since",
value: modTime.Format(http.TimeFormat),
})
return nil
}

View File

@ -70,7 +70,13 @@ func (c Core) ListMultipartUploads(bucket, prefix, keyMarker, uploadIDMarker, de
// PutObjectPart - Upload an object part. // PutObjectPart - Upload an object part.
func (c Core) PutObjectPart(bucket, object, uploadID string, partID int, size int64, data io.Reader, md5Sum, sha256Sum []byte) (ObjectPart, error) { func (c Core) PutObjectPart(bucket, object, uploadID string, partID int, size int64, data io.Reader, md5Sum, sha256Sum []byte) (ObjectPart, error) {
return c.uploadPart(bucket, object, uploadID, data, partID, md5Sum, sha256Sum, size) return c.PutObjectPartWithMetadata(bucket, object, uploadID, partID, size, data, md5Sum, sha256Sum, nil)
}
// PutObjectPartWithMetadata - upload an object part with additional request metadata.
func (c Core) PutObjectPartWithMetadata(bucket, object, uploadID string, partID int,
size int64, data io.Reader, md5Sum, sha256Sum []byte, metadata map[string][]string) (ObjectPart, error) {
return c.uploadPart(bucket, object, uploadID, data, partID, md5Sum, sha256Sum, size, metadata)
} }
// ListObjectParts - List uploaded parts of an incomplete upload.x // ListObjectParts - List uploaded parts of an incomplete upload.x
@ -80,7 +86,9 @@ func (c Core) ListObjectParts(bucket, object, uploadID string, partNumberMarker
// CompleteMultipartUpload - Concatenate uploaded parts and commit to an object. // CompleteMultipartUpload - Concatenate uploaded parts and commit to an object.
func (c Core) CompleteMultipartUpload(bucket, object, uploadID string, parts []CompletePart) error { func (c Core) CompleteMultipartUpload(bucket, object, uploadID string, parts []CompletePart) error {
_, err := c.completeMultipartUpload(bucket, object, uploadID, completeMultipartUpload{Parts: parts}) _, err := c.completeMultipartUpload(bucket, object, uploadID, completeMultipartUpload{
Parts: parts,
})
return err return err
} }

View File

@ -18,14 +18,15 @@ package minio
import ( import (
"bytes" "bytes"
"crypto/md5"
"io" "io"
"math/rand" "log"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"crypto/md5"
"math/rand"
) )
const ( const (
@ -35,6 +36,33 @@ const (
enableSecurity = "ENABLE_HTTPS" enableSecurity = "ENABLE_HTTPS"
) )
// Minimum part size
const MinPartSize = 1024 * 1024 * 64
const letterBytes = "abcdefghijklmnopqrstuvwxyz01234569"
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
// randString generates random names and prepends them with a known prefix.
func randString(n int, src rand.Source, prefix string) string {
b := make([]byte, n)
// A rand.Int63() generates 63 random bits, enough for letterIdxMax letters!
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return prefix + string(b[0:30-len(prefix)])
}
// Tests for Core GetObject() function. // Tests for Core GetObject() function.
func TestGetObjectCore(t *testing.T) { func TestGetObjectCore(t *testing.T) {
if testing.Short() { if testing.Short() {
@ -209,6 +237,76 @@ func TestGetObjectCore(t *testing.T) {
} }
} }
// Tests GetObject to return Content-Encoding properly set
// and overrides any auto decoding.
func TestGetObjectContentEncoding(t *testing.T) {
if testing.Short() {
t.Skip("skipping functional tests for the short runs")
}
// Seed random based on current time.
rand.Seed(time.Now().Unix())
// Instantiate new minio core client object.
c, err := NewCore(
os.Getenv(serverEndpoint),
os.Getenv(accessKey),
os.Getenv(secretKey),
mustParseBool(os.Getenv(enableSecurity)),
)
if err != nil {
t.Fatal("Error:", err)
}
// Enable tracing, write to stderr.
// c.TraceOn(os.Stderr)
// Set user agent.
c.SetAppInfo("Minio-go-FunctionalTest", "0.1.0")
// Generate a new random bucket name.
bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test")
// Make a new bucket.
err = c.MakeBucket(bucketName, "us-east-1")
if err != nil {
t.Fatal("Error:", err, bucketName)
}
// Generate data more than 32K
buf := bytes.Repeat([]byte("3"), rand.Intn(1<<20)+32*1024)
m := make(map[string][]string)
m["Content-Encoding"] = []string{"gzip"}
// Save the data
objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
n, err := c.Client.PutObjectWithMetadata(bucketName, objectName, bytes.NewReader(buf), m, nil)
if err != nil {
t.Fatal("Error:", err, bucketName, objectName)
}
if n != int64(len(buf)) {
t.Fatalf("Error: number of bytes does not match, want %v, got %v\n", len(buf), n)
}
reqHeaders := NewGetReqHeaders()
rwc, objInfo, err := c.GetObject(bucketName, objectName, reqHeaders)
if err != nil {
t.Fatalf("Error: %v", err)
}
rwc.Close()
if objInfo.Size <= 0 {
t.Fatalf("Unexpected size of the object %v, expected %v", objInfo.Size, n)
}
value, ok := objInfo.Metadata["Content-Encoding"]
if !ok {
t.Fatalf("Expected Content-Encoding metadata to be set.")
}
if value[0] != "gzip" {
t.Fatalf("Unexpected content-encoding found, want gzip, got %v", value)
}
}
// Tests get bucket policy core API. // Tests get bucket policy core API.
func TestGetBucketPolicy(t *testing.T) { func TestGetBucketPolicy(t *testing.T) {
if testing.Short() { if testing.Short() {
@ -373,3 +471,48 @@ func TestCorePutObject(t *testing.T) {
t.Fatal("Error:", err) t.Fatal("Error:", err)
} }
} }
func TestCoreGetObjectMetadata(t *testing.T) {
if testing.Short() {
t.Skip("skipping functional tests for the short runs")
}
core, err := NewCore(
os.Getenv(serverEndpoint),
os.Getenv(accessKey),
os.Getenv(secretKey),
mustParseBool(os.Getenv(enableSecurity)))
if err != nil {
log.Fatalln(err)
}
// Generate a new random bucket name.
bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test")
// Make a new bucket.
err = core.MakeBucket(bucketName, "us-east-1")
if err != nil {
t.Fatal("Error:", err, bucketName)
}
metadata := map[string][]string{
"X-Amz-Meta-Key-1": {"Val-1"},
}
_, err = core.PutObject(bucketName, "my-objectname", 5,
bytes.NewReader([]byte("hello")), nil, nil, metadata)
if err != nil {
log.Fatalln(err)
}
reader, objInfo, err := core.GetObject(bucketName, "my-objectname",
RequestHeaders{})
if err != nil {
log.Fatalln(err)
}
defer reader.Close()
if objInfo.Metadata.Get("X-Amz-Meta-Key-1") != "Val-1" {
log.Fatalln("Expected metadata to be available but wasn't")
}
}

View File

@ -50,17 +50,21 @@ func main() {
} }
``` ```
| Bucket operations |Object operations | Encrypted Object operations | Presigned operations | Bucket Policy/Notification Operations | Client custom settings | | Bucket operations | Object operations | Encrypted Object operations | Presigned operations | Bucket Policy/Notification Operations | Client custom settings |
|:---|:---|:---|:---|:---|:---| | :--- | :--- | :--- | :--- | :--- | :--- |
|[`MakeBucket`](#MakeBucket) |[`GetObject`](#GetObject) | [`NewSymmetricKey`](#NewSymmetricKey) | [`PresignedGetObject`](#PresignedGetObject) |[`SetBucketPolicy`](#SetBucketPolicy) | [`SetAppInfo`](#SetAppInfo) | | [`MakeBucket`](#MakeBucket) | [`GetObject`](#GetObject) | [`NewSymmetricKey`](#NewSymmetricKey) | [`PresignedGetObject`](#PresignedGetObject) | [`SetBucketPolicy`](#SetBucketPolicy) | [`SetAppInfo`](#SetAppInfo) |
|[`ListBuckets`](#ListBuckets) |[`PutObject`](#PutObject) | [`NewAsymmetricKey`](#NewAsymmetricKey) |[`PresignedPutObject`](#PresignedPutObject) | [`GetBucketPolicy`](#GetBucketPolicy) | [`SetCustomTransport`](#SetCustomTransport) | | [`ListBuckets`](#ListBuckets) | [`PutObject`](#PutObject) | [`NewAsymmetricKey`](#NewAsymmetricKey) | [`PresignedPutObject`](#PresignedPutObject) | [`GetBucketPolicy`](#GetBucketPolicy) | [`SetCustomTransport`](#SetCustomTransport) |
|[`BucketExists`](#BucketExists) |[`CopyObject`](#CopyObject) | [`GetEncryptedObject`](#GetEncryptedObject) |[`PresignedPostPolicy`](#PresignedPostPolicy) | [`ListBucketPolicies`](#ListBucketPolicies) | [`TraceOn`](#TraceOn) | | [`BucketExists`](#BucketExists) | [`CopyObject`](#CopyObject) | [`GetEncryptedObject`](#GetEncryptedObject) | [`PresignedPostPolicy`](#PresignedPostPolicy) | [`ListBucketPolicies`](#ListBucketPolicies) | [`TraceOn`](#TraceOn) |
| [`RemoveBucket`](#RemoveBucket) |[`StatObject`](#StatObject) | [`PutObjectStreaming`](#PutObjectStreaming) | | [`SetBucketNotification`](#SetBucketNotification) | [`TraceOff`](#TraceOff) | | [`RemoveBucket`](#RemoveBucket) | [`StatObject`](#StatObject) | [`PutObjectStreaming`](#PutObjectStreaming) | | [`SetBucketNotification`](#SetBucketNotification) | [`TraceOff`](#TraceOff) |
|[`ListObjects`](#ListObjects) |[`RemoveObject`](#RemoveObject) | [`PutEncryptedObject`](#PutEncryptedObject) | | [`GetBucketNotification`](#GetBucketNotification) | [`SetS3TransferAccelerate`](#SetS3TransferAccelerate) | | [`ListObjects`](#ListObjects) | [`RemoveObject`](#RemoveObject) | [`PutEncryptedObject`](#PutEncryptedObject) | | [`GetBucketNotification`](#GetBucketNotification) | [`SetS3TransferAccelerate`](#SetS3TransferAccelerate) |
|[`ListObjectsV2`](#ListObjectsV2) | [`RemoveObjects`](#RemoveObjects) | | | [`RemoveAllBucketNotification`](#RemoveAllBucketNotification) | | [`ListObjectsV2`](#ListObjectsV2) | [`RemoveObjects`](#RemoveObjects) | [`NewSSEInfo`](#NewSSEInfo) | | [`RemoveAllBucketNotification`](#RemoveAllBucketNotification) | |
|[`ListIncompleteUploads`](#ListIncompleteUploads) | [`RemoveIncompleteUpload`](#RemoveIncompleteUpload) | | | [`ListenBucketNotification`](#ListenBucketNotification) | | [`ListIncompleteUploads`](#ListIncompleteUploads) | [`RemoveIncompleteUpload`](#RemoveIncompleteUpload) | | | [`ListenBucketNotification`](#ListenBucketNotification) | |
| | [`FPutObject`](#FPutObject) | | | | | | [`FPutObject`](#FPutObject) | | | | |
| | [`FGetObject`](#FGetObject) | | | | | | [`FGetObject`](#FGetObject) | | | | |
| | [`ComposeObject`](#ComposeObject) | | | | |
| | [`NewSourceInfo`](#NewSourceInfo) | | | | |
| | [`NewDestinationInfo`](#NewDestinationInfo) | | | | |
## 1. Constructor ## 1. Constructor
<a name="Minio"></a> <a name="Minio"></a>
@ -502,9 +506,11 @@ if err != nil {
<a name="CopyObject"></a> <a name="CopyObject"></a>
### CopyObject(bucketName, objectName, objectSource string, conditions CopyConditions) error ### CopyObject(dst DestinationInfo, src SourceInfo) error
Copy a source object into a new object with the provided name in the provided bucket. Create or replace an object through server-side copying of an existing object. It supports conditional copying, copying a part of an object and server-side encryption of destination and decryption of source. See the `SourceInfo` and `DestinationInfo` types for further details.
To copy multiple source objects into a single destination object see the `ComposeObject` API.
__Parameters__ __Parameters__
@ -512,50 +518,169 @@ __Parameters__
|Param |Type |Description | |Param |Type |Description |
|:---|:---| :---| |:---|:---| :---|
|`bucketName` | _string_ |Name of the bucket | |`dst` | _DestinationInfo_ |Argument describing the destination object |
|`objectName` | _string_ |Name of the object | |`src` | _SourceInfo_ |Argument describing the source object |
|`objectSource` | _string_ |Name of the source object |
|`conditions` | _CopyConditions_ |Collection of supported CopyObject conditions. [`x-amz-copy-source`, `x-amz-copy-source-if-match`, `x-amz-copy-source-if-none-match`, `x-amz-copy-source-if-unmodified-since`, `x-amz-copy-source-if-modified-since`]|
__Example__ __Example__
```go ```go
// Use-case-1 // Use-case 1: Simple copy object with no conditions, etc
// To copy an existing object to a new object with _no_ copy conditions. // Source object
copyConds := minio.CopyConditions{} src := minio.NewSourceInfo("my-sourcebucketname", "my-sourceobjectname", nil)
err := minioClient.CopyObject("mybucket", "myobject", "my-sourcebucketname/my-sourceobjectname", copyConds)
// Destination object
dst, err := minio.NewDestinationInfo("my-bucketname", "my-objectname", nil, nil)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
// Use-case-2 // Copy object call
// To copy an existing object to a new object with the following copy conditions err = s3Client.CopyObject(dst, src)
if err != nil {
fmt.Println(err)
return
}
// Use-case 2: Copy object with copy-conditions, and copying only part of the source object.
// 1. that matches a given ETag // 1. that matches a given ETag
// 2. and modified after 1st April 2014 // 2. and modified after 1st April 2014
// 3. but unmodified since 23rd April 2014 // 3. but unmodified since 23rd April 2014
// 4. copy only first 1MiB of object.
// Initialize empty copy conditions. // Source object
var copyConds = minio.CopyConditions{} src := minio.NewSourceInfo("my-sourcebucketname", "my-sourceobjectname", nil)
// copy object that matches the given ETag. // Set matching ETag condition, copy object which matches the following ETag.
copyConds.SetMatchETag("31624deb84149d2f8ef9c385918b653a") src.SetMatchETagCond("31624deb84149d2f8ef9c385918b653a")
// and modified after 1st April 2014 // Set modified condition, copy object modified since 2014 April 1.
copyConds.SetModified(time.Date(2014, time.April, 1, 0, 0, 0, 0, time.UTC)) src.SetModifiedSinceCond(time.Date(2014, time.April, 1, 0, 0, 0, 0, time.UTC))
// but unmodified since 23rd April 2014 // Set unmodified condition, copy object unmodified since 2014 April 23.
copyConds.SetUnmodified(time.Date(2014, time.April, 23, 0, 0, 0, 0, time.UTC)) src.SetUnmodifiedSinceCond(time.Date(2014, time.April, 23, 0, 0, 0, 0, time.UTC))
err := minioClient.CopyObject("mybucket", "myobject", "my-sourcebucketname/my-sourceobjectname", copyConds) // Set copy-range of only first 1MiB of file.
src.SetRange(0, 1024*1024-1)
// Destination object
dst, err := minio.NewDestinationInfo("my-bucketname", "my-objectname", nil, nil)
if err != nil {
fmt.Println(err)
return
}
// Copy object call
err = s3Client.CopyObject(dst, src)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
``` ```
<a name="ComposeObject"></a>
### ComposeObject(dst DestinationInfo, srcs []SourceInfo) error
Create an object by concatenating a list of source objects using
server-side copying.
__Parameters__
|Param |Type |Description |
|:---|:---|:---|
|`dst` | _minio.DestinationInfo_ |Struct with info about the object to be created. |
|`srcs` | _[]minio.SourceInfo_ |Slice of struct with info about source objects to be concatenated in order. |
__Example__
```go
// Prepare source decryption key (here we assume same key to
// decrypt all source objects.)
decKey := minio.NewSSEInfo([]byte{1, 2, 3}, "")
// Source objects to concatenate. We also specify decryption
// key for each
src1 := minio.NewSourceInfo("bucket1", "object1", decKey)
src1.SetMatchETag("31624deb84149d2f8ef9c385918b653a")
src2 := minio.NewSourceInfo("bucket2", "object2", decKey)
src2.SetMatchETag("f8ef9c385918b653a31624deb84149d2")
src3 := minio.NewSourceInfo("bucket3", "object3", decKey)
src3.SetMatchETag("5918b653a31624deb84149d2f8ef9c38")
// Create slice of sources.
srcs := []minio.SourceInfo{src1, src2, src3}
// Prepare destination encryption key
encKey := minio.NewSSEInfo([]byte{8, 9, 0}, "")
// Create destination info
dst := minio.NewDestinationInfo("bucket", "object", encKey, nil)
err = s3Client.ComposeObject(dst, srcs)
if err != nil {
log.Println(err)
return
}
log.Println("Composed object successfully.")
```
<a name="NewSourceInfo"></a>
### NewSourceInfo(bucket, object string, decryptSSEC *SSEInfo) SourceInfo
Construct a `SourceInfo` object that can be used as the source for server-side copying operations like `CopyObject` and `ComposeObject`. This object can be used to set copy-conditions on the source.
__Parameters__
| Param | Type | Description |
| :--- | :--- | :--- |
| `bucket` | _string_ | Name of the source bucket |
| `object` | _string_ | Name of the source object |
| `decryptSSEC` | _*minio.SSEInfo_ | Decryption info for the source object (`nil` without encryption) |
__Example__
``` go
// No decryption parameter.
src := NewSourceInfo("bucket", "object", nil)
// With decryption parameter.
decKey := NewSSEKey([]byte{1,2,3}, "")
src := NewSourceInfo("bucket", "object", decKey)
```
<a name="NewDestinationInfo"></a>
### NewDestinationInfo(bucket, object string, encryptSSEC *SSEInfo, userMeta map[string]string) (DestinationInfo, error)
Construct a `DestinationInfo` object that can be used as the destination object for server-side copying operations like `CopyObject` and `ComposeObject`.
__Parameters__
| Param | Type | Description |
| :--- | :--- | :--- |
| `bucket` | _string_ | Name of the destination bucket |
| `object` | _string_ | Name of the destination object |
| `encryptSSEC` | _*minio.SSEInfo_ | Encryption info for the source object (`nil` without encryption) |
| `userMeta` | _map[string]string_ | User metadata to be set on the destination. If nil, with only one source, user-metadata is copied from source. |
__Example__
``` go
// No encryption parameter.
dst, err := NewDestinationInfo("bucket", "object", nil, nil)
// With encryption parameter.
encKey := NewSSEKey([]byte{1,2,3}, "")
dst, err := NewDecryptionInfo("bucket", "object", encKey, nil)
```
<a name="FPutObject"></a> <a name="FPutObject"></a>
### FPutObject(bucketName, objectName, filePath, contentType string) (length int64, err error) ### FPutObject(bucketName, objectName, filePath, contentType string) (length int64, err error)
@ -881,6 +1006,26 @@ if err != nil {
} }
``` ```
<a name="NewSSEInfo"></a>
### NewSSEInfo(key []byte, algo string) SSEInfo
Create a key object for use as encryption or decryption parameter in operations involving server-side-encryption with customer provided key (SSE-C).
__Parameters__
| Param | Type | Description |
| :--- | :--- | :--- |
| `key` | _[]byte_ | Byte-slice of the raw, un-encoded binary key |
| `algo` | _string_ | Algorithm to use in encryption or decryption with the given key. Can be empty (defaults to `AES256`) |
__Example__
``` go
// Key for use in encryption/decryption
keyInfo := NewSSEInfo([]byte{1,2,3}, "")
```
## 5. Presigned operations ## 5. Presigned operations
<a name="PresignedGetObject"></a> <a name="PresignedGetObject"></a>

View File

@ -0,0 +1,77 @@
// +build ignore
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"log"
minio "github.com/minio/minio-go"
)
func main() {
// Note: YOUR-ACCESSKEYID, YOUR-SECRETACCESSKEY, my-testfile, my-bucketname and
// my-objectname are dummy values, please replace them with original values.
// Requests are always secure (HTTPS) by default. Set secure=false to enable insecure (HTTP) access.
// This boolean value is the last argument for New().
// New returns an Amazon S3 compatible client object. API compatibility (v2 or v4) is automatically
// determined based on the Endpoint value.
s3Client, err := minio.New("s3.amazonaws.com", "YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", true)
if err != nil {
log.Fatalln(err)
}
// Enable trace.
// s3Client.TraceOn(os.Stderr)
// Prepare source decryption key (here we assume same key to
// decrypt all source objects.)
decKey := minio.NewSSEInfo([]byte{1, 2, 3}, "")
// Source objects to concatenate. We also specify decryption
// key for each
src1 := minio.NewSourceInfo("bucket1", "object1", &decKey)
src1.SetMatchETagCond("31624deb84149d2f8ef9c385918b653a")
src2 := minio.NewSourceInfo("bucket2", "object2", &decKey)
src2.SetMatchETagCond("f8ef9c385918b653a31624deb84149d2")
src3 := minio.NewSourceInfo("bucket3", "object3", &decKey)
src3.SetMatchETagCond("5918b653a31624deb84149d2f8ef9c38")
// Create slice of sources.
srcs := []minio.SourceInfo{src1, src2, src3}
// Prepare destination encryption key
encKey := minio.NewSSEInfo([]byte{8, 9, 0}, "")
// Create destination info
dst, err := minio.NewDestinationInfo("bucket", "object", &encKey, nil)
if err != nil {
log.Fatalln(err)
}
err = s3Client.ComposeObject(dst, srcs)
if err != nil {
log.Fatalln(err)
}
log.Println("Composed object successfully.")
}

View File

@ -42,24 +42,31 @@ func main() {
// Enable trace. // Enable trace.
// s3Client.TraceOn(os.Stderr) // s3Client.TraceOn(os.Stderr)
// Source object
src := minio.NewSourceInfo("my-sourcebucketname", "my-sourceobjectname", nil)
// All following conditions are allowed and can be combined together. // All following conditions are allowed and can be combined together.
// Set copy conditions.
var copyConds = minio.CopyConditions{}
// Set modified condition, copy object modified since 2014 April. // Set modified condition, copy object modified since 2014 April.
copyConds.SetModified(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC)) src.SetModifiedSinceCond(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC))
// Set unmodified condition, copy object unmodified since 2014 April. // Set unmodified condition, copy object unmodified since 2014 April.
// copyConds.SetUnmodified(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC)) // src.SetUnmodifiedSinceCond(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC))
// Set matching ETag condition, copy object which matches the following ETag. // Set matching ETag condition, copy object which matches the following ETag.
// copyConds.SetMatchETag("31624deb84149d2f8ef9c385918b653a") // src.SetMatchETagCond("31624deb84149d2f8ef9c385918b653a")
// Set matching ETag except condition, copy object which does not match the following ETag. // Set matching ETag except condition, copy object which does not match the following ETag.
// copyConds.SetMatchETagExcept("31624deb84149d2f8ef9c385918b653a") // src.SetMatchETagExceptCond("31624deb84149d2f8ef9c385918b653a")
// Destination object
dst, err := minio.NewDestinationInfo("my-bucketname", "my-objectname", nil, nil)
if err != nil {
log.Fatalln(err)
}
// Initiate copy object. // Initiate copy object.
err = s3Client.CopyObject("my-bucketname", "my-objectname", "/my-sourcebucketname/my-sourceobjectname", copyConds) err = s3Client.CopyObject(dst, src)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }

View File

@ -55,7 +55,9 @@ func main() {
progress := pb.New64(objectInfo.Size) progress := pb.New64(objectInfo.Size)
progress.Start() progress.Start()
n, err := s3Client.PutObjectWithProgress("my-bucketname", "my-objectname-progress", reader, "application/octet-stream", progress) n, err := s3Client.PutObjectWithProgress("my-bucketname", "my-objectname-progress", reader, map[string][]string{
"Content-Type": []string{"application/octet-stream"},
}, progress)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }

View File

@ -40,7 +40,7 @@ func main() {
} }
// Enable S3 transfer accelerate endpoint. // Enable S3 transfer accelerate endpoint.
s3Client.S3TransferAccelerate("s3-accelerate.amazonaws.com") s3Client.SetS3TransferAccelerate("s3-accelerate.amazonaws.com")
object, err := os.Open("my-testfile") object, err := os.Open("my-testfile")
if err != nil { if err != nil {

File diff suppressed because it is too large Load Diff

View File

@ -99,7 +99,7 @@ func prepareStreamingRequest(req *http.Request, sessionToken string, dataLen int
if sessionToken != "" { if sessionToken != "" {
req.Header.Set("X-Amz-Security-Token", sessionToken) req.Header.Set("X-Amz-Security-Token", sessionToken)
} }
req.Header.Set("Content-Encoding", streamingEncoding) req.Header.Add("Content-Encoding", streamingEncoding)
req.Header.Set("X-Amz-Date", timestamp.Format(iso8601DateFormat)) req.Header.Set("X-Amz-Date", timestamp.Format(iso8601DateFormat))
// Set content length with streaming signature for each chunk included. // Set content length with streaming signature for each chunk included.
@ -254,7 +254,18 @@ func (s *StreamingReader) Read(buf []byte) (int, error) {
s.chunkBufLen = 0 s.chunkBufLen = 0
for { for {
n1, err := s.baseReadCloser.Read(s.chunkBuf[s.chunkBufLen:]) n1, err := s.baseReadCloser.Read(s.chunkBuf[s.chunkBufLen:])
if err == nil || err == io.ErrUnexpectedEOF { // Usually we validate `err` first, but in this case
// we are validating n > 0 for the following reasons.
//
// 1. n > 0, err is one of io.EOF, nil (near end of stream)
// A Reader returning a non-zero number of bytes at the end
// of the input stream may return either err == EOF or err == nil
//
// 2. n == 0, err is io.EOF (actual end of stream)
//
// Callers should always process the n > 0 bytes returned
// before considering the error err.
if n1 > 0 {
s.chunkBufLen += n1 s.chunkBufLen += n1
s.bytesRead += int64(n1) s.bytesRead += int64(n1)
@ -265,25 +276,26 @@ func (s *StreamingReader) Read(buf []byte) (int, error) {
s.signChunk(s.chunkBufLen) s.signChunk(s.chunkBufLen)
break break
} }
}
if err != nil {
if err == io.EOF {
// No more data left in baseReader - last chunk.
// Done reading the last chunk from baseReader.
s.done = true
} else if err == io.EOF { // bytes read from baseReader different than
// No more data left in baseReader - last chunk. // content length provided.
// Done reading the last chunk from baseReader. if s.bytesRead != s.contentLen {
s.done = true return 0, io.ErrUnexpectedEOF
}
// bytes read from baseReader different than // Sign the chunk and write it to s.buf.
// content length provided. s.signChunk(0)
if s.bytesRead != s.contentLen { break
return 0, io.ErrUnexpectedEOF
} }
// Sign the chunk and write it to s.buf.
s.signChunk(0)
break
} else {
return 0, err return 0, err
} }
} }
} }
return s.buf.Read(buf) return s.buf.Read(buf)

View File

@ -1,60 +0,0 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"io/ioutil"
"os"
"sync"
)
// tempFile - temporary file container.
type tempFile struct {
*os.File
mutex *sync.Mutex
}
// newTempFile returns a new temporary file, once closed it automatically deletes itself.
func newTempFile(prefix string) (*tempFile, error) {
// use platform specific temp directory.
file, err := ioutil.TempFile(os.TempDir(), prefix)
if err != nil {
return nil, err
}
return &tempFile{
File: file,
mutex: &sync.Mutex{},
}, nil
}
// Close - closer wrapper to close and remove temporary file.
func (t *tempFile) Close() error {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.File != nil {
// Close the file.
if err := t.File.Close(); err != nil {
return err
}
// Remove file.
if err := os.Remove(t.File.Name()); err != nil {
return err
}
t.File = nil
}
return nil
}

View File

@ -64,11 +64,11 @@ func encodeResponse(response interface{}) []byte {
return bytesBuffer.Bytes() return bytesBuffer.Bytes()
} }
// Convert string to bool and always return true if any error // Convert string to bool and always return false if any error
func mustParseBool(str string) bool { func mustParseBool(str string) bool {
b, err := strconv.ParseBool(str) b, err := strconv.ParseBool(str)
if err != nil { if err != nil {
return true return false
} }
return b return b
} }

View File

@ -0,0 +1,48 @@
// +build go1.7 go1.8
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage
* (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"net"
"net/http"
"time"
)
// This default transport is similar to http.DefaultTransport
// but with additional DisableCompression:
var defaultMinioTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Set this value so that the underlying transport round-tripper
// doesn't try to auto decode the body of objects with
// content-encoding set to `gzip`.
//
// Refer:
// https://golang.org/src/net/http/transport.go?h=roundTrip#L1843
DisableCompression: true,
}

View File

@ -0,0 +1,39 @@
// +build go1.5,!go1.6,!go1.7,!go1.8
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage
* (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"net/http"
"time"
)
// This default transport is similar to http.DefaultTransport
// but with additional DisableCompression:
var defaultMinioTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
// Set this value so that the underlying transport round-tripper
// doesn't try to auto decode the body of objects with
// content-encoding set to `gzip`.
//
// Refer:
// https://golang.org/src/net/http/transport.go?h=roundTrip#L1843
DisableCompression: true,
}

View File

@ -0,0 +1,40 @@
// +build go1.6,!go1.7,!go1.8
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage
* (C) 2017 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package minio
import (
"net/http"
"time"
)
// This default transport is similar to http.DefaultTransport
// but with additional DisableCompression:
var defaultMinioTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Set this value so that the underlying transport round-tripper
// doesn't try to auto decode the body of objects with
// content-encoding set to `gzip`.
//
// Refer:
// https://golang.org/src/net/http/transport.go?h=roundTrip#L1843
DisableCompression: true,
}

View File

@ -122,7 +122,7 @@ func isValidEndpointURL(endpointURL url.URL) error {
if endpointURL.Path != "/" && endpointURL.Path != "" { if endpointURL.Path != "/" && endpointURL.Path != "" {
return ErrInvalidArgument("Endpoint url cannot have fully qualified paths.") return ErrInvalidArgument("Endpoint url cannot have fully qualified paths.")
} }
if strings.Contains(endpointURL.Host, ".amazonaws.com") { if strings.Contains(endpointURL.Host, ".s3.amazonaws.com") {
if !s3utils.IsAmazonEndpoint(endpointURL) { if !s3utils.IsAmazonEndpoint(endpointURL) {
return ErrInvalidArgument("Amazon S3 endpoint should be 's3.amazonaws.com'.") return ErrInvalidArgument("Amazon S3 endpoint should be 's3.amazonaws.com'.")
} }

View File

@ -84,9 +84,9 @@ func TestGetEndpointURL(t *testing.T) {
{"s3.cn-north-1.amazonaws.com.cn", false, "http://s3.cn-north-1.amazonaws.com.cn", nil, true}, {"s3.cn-north-1.amazonaws.com.cn", false, "http://s3.cn-north-1.amazonaws.com.cn", nil, true},
{"192.168.1.1:9000", false, "http://192.168.1.1:9000", nil, true}, {"192.168.1.1:9000", false, "http://192.168.1.1:9000", nil, true},
{"192.168.1.1:9000", true, "https://192.168.1.1:9000", nil, true}, {"192.168.1.1:9000", true, "https://192.168.1.1:9000", nil, true},
{"s3.amazonaws.com:443", true, "https://s3.amazonaws.com:443", nil, true},
{"13333.123123.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "13333.123123.-")), false}, {"13333.123123.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "13333.123123.-")), false},
{"13333.123123.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "13333.123123.-")), false}, {"13333.123123.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "13333.123123.-")), false},
{"s3.amazonaws.com:443", true, "", ErrInvalidArgument("Amazon S3 endpoint should be 's3.amazonaws.com'."), false},
{"storage.googleapis.com:4000", true, "", ErrInvalidArgument("Google Cloud Storage endpoint should be 'storage.googleapis.com'."), false}, {"storage.googleapis.com:4000", true, "", ErrInvalidArgument("Google Cloud Storage endpoint should be 'storage.googleapis.com'."), false},
{"s3.aamzza.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "s3.aamzza.-")), false}, {"s3.aamzza.-", true, "", ErrInvalidArgument(fmt.Sprintf("Endpoint: %s does not follow ip address or domain name standards.", "s3.aamzza.-")), false},
{"", true, "", ErrInvalidArgument("Endpoint: does not follow ip address or domain name standards."), false}, {"", true, "", ErrInvalidArgument("Endpoint: does not follow ip address or domain name standards."), false},
@ -132,10 +132,11 @@ func TestIsValidEndpointURL(t *testing.T) {
{"https://s3-fips-us-gov-west-1.amazonaws.com", nil, true}, {"https://s3-fips-us-gov-west-1.amazonaws.com", nil, true},
{"https://s3.amazonaws.com/", nil, true}, {"https://s3.amazonaws.com/", nil, true},
{"https://storage.googleapis.com/", nil, true}, {"https://storage.googleapis.com/", nil, true},
{"https://z3.amazonaws.com", nil, true},
{"https://mybalancer.us-east-1.elb.amazonaws.com", nil, true},
{"192.168.1.1", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false}, {"192.168.1.1", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false},
{"https://amazon.googleapis.com/", ErrInvalidArgument("Google Cloud Storage endpoint should be 'storage.googleapis.com'."), false}, {"https://amazon.googleapis.com/", ErrInvalidArgument("Google Cloud Storage endpoint should be 'storage.googleapis.com'."), false},
{"https://storage.googleapis.com/bucket/", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false}, {"https://storage.googleapis.com/bucket/", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false},
{"https://z3.amazonaws.com", ErrInvalidArgument("Amazon S3 endpoint should be 's3.amazonaws.com'."), false},
{"https://s3.amazonaws.com/bucket/object", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false}, {"https://s3.amazonaws.com/bucket/object", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false},
} }

View File

@ -75,6 +75,7 @@ func makeConnection(t *testing.T) (*swift.Connection, func()) {
ConnectionChannelTimeout := os.Getenv("SWIFT_CONNECTION_CHANNEL_TIMEOUT") ConnectionChannelTimeout := os.Getenv("SWIFT_CONNECTION_CHANNEL_TIMEOUT")
DataChannelTimeout := os.Getenv("SWIFT_DATA_CHANNEL_TIMEOUT") DataChannelTimeout := os.Getenv("SWIFT_DATA_CHANNEL_TIMEOUT")
internalServer := false
if UserName == "" || ApiKey == "" || AuthUrl == "" { if UserName == "" || ApiKey == "" || AuthUrl == "" {
srv, err = swifttest.NewSwiftServer("localhost") srv, err = swifttest.NewSwiftServer("localhost")
if err != nil && t != nil { if err != nil && t != nil {
@ -84,6 +85,7 @@ func makeConnection(t *testing.T) (*swift.Connection, func()) {
UserName = "swifttest" UserName = "swifttest"
ApiKey = "swifttest" ApiKey = "swifttest"
AuthUrl = srv.AuthURL AuthUrl = srv.AuthURL
internalServer = true
} }
transport := &http.Transport{ transport := &http.Transport{
@ -105,6 +107,16 @@ func makeConnection(t *testing.T) (*swift.Connection, func()) {
EndpointType: swift.EndpointType(EndpointType), EndpointType: swift.EndpointType(EndpointType),
} }
if !internalServer {
if isV3Api() {
c.Tenant = os.Getenv("SWIFT_TENANT")
c.Domain = os.Getenv("SWIFT_API_DOMAIN")
} else {
c.Tenant = os.Getenv("SWIFT_TENANT")
c.TenantId = os.Getenv("SWIFT_TENANT_ID")
}
}
var timeout int64 var timeout int64
if ConnectionChannelTimeout != "" { if ConnectionChannelTimeout != "" {
timeout, err = strconv.ParseInt(ConnectionChannelTimeout, 10, 32) timeout, err = strconv.ParseInt(ConnectionChannelTimeout, 10, 32)
@ -304,14 +316,6 @@ func TestTransport(t *testing.T) {
c.Transport = tr c.Transport = tr
if isV3Api() {
c.Tenant = os.Getenv("SWIFT_TENANT")
c.Domain = os.Getenv("SWIFT_API_DOMAIN")
} else {
c.Tenant = os.Getenv("SWIFT_TENANT")
c.TenantId = os.Getenv("SWIFT_TENANT_ID")
}
err := c.Authenticate() err := c.Authenticate()
if err != nil { if err != nil {
t.Fatal("Auth failed", err) t.Fatal("Auth failed", err)
@ -329,9 +333,6 @@ func TestV1V2Authenticate(t *testing.T) {
c, rollback := makeConnection(t) c, rollback := makeConnection(t)
defer rollback() defer rollback()
c.Tenant = os.Getenv("SWIFT_TENANT")
c.TenantId = os.Getenv("SWIFT_TENANT_ID")
err := c.Authenticate() err := c.Authenticate()
if err != nil { if err != nil {
t.Fatal("Auth failed", err) t.Fatal("Auth failed", err)
@ -349,8 +350,10 @@ func TestV3AuthenticateWithDomainNameAndTenantId(t *testing.T) {
c, rollback := makeConnection(t) c, rollback := makeConnection(t)
defer rollback() defer rollback()
c.TenantId = os.Getenv("SWIFT_TENANT_ID") c.Tenant = ""
c.Domain = os.Getenv("SWIFT_API_DOMAIN") c.Domain = os.Getenv("SWIFT_API_DOMAIN")
c.TenantId = os.Getenv("SWIFT_TENANT_ID")
c.DomainId = ""
err := c.Authenticate() err := c.Authenticate()
if err != nil { if err != nil {
@ -388,6 +391,8 @@ func TestV3AuthenticateWithDomainIdAndTenantId(t *testing.T) {
c, rollback := makeConnection(t) c, rollback := makeConnection(t)
defer rollback() defer rollback()
c.Tenant = ""
c.Domain = ""
c.TenantId = os.Getenv("SWIFT_TENANT_ID") c.TenantId = os.Getenv("SWIFT_TENANT_ID")
c.DomainId = os.Getenv("SWIFT_API_DOMAIN_ID") c.DomainId = os.Getenv("SWIFT_API_DOMAIN_ID")
@ -410,6 +415,8 @@ func TestV3AuthenticateWithDomainNameAndTenantName(t *testing.T) {
c.Tenant = os.Getenv("SWIFT_TENANT") c.Tenant = os.Getenv("SWIFT_TENANT")
c.Domain = os.Getenv("SWIFT_API_DOMAIN") c.Domain = os.Getenv("SWIFT_API_DOMAIN")
c.TenantId = ""
c.DomainId = ""
err := c.Authenticate() err := c.Authenticate()
if err != nil { if err != nil {
@ -429,6 +436,8 @@ func TestV3AuthenticateWithDomainIdAndTenantName(t *testing.T) {
defer rollback() defer rollback()
c.Tenant = os.Getenv("SWIFT_TENANT") c.Tenant = os.Getenv("SWIFT_TENANT")
c.Domain = ""
c.TenantId = ""
c.DomainId = os.Getenv("SWIFT_API_DOMAIN_ID") c.DomainId = os.Getenv("SWIFT_API_DOMAIN_ID")
err := c.Authenticate() err := c.Authenticate()

View File

@ -15,6 +15,7 @@ func noErrors(at, depth int) error {
} }
return noErrors(at+1, depth) return noErrors(at+1, depth)
} }
func yesErrors(at, depth int) error { func yesErrors(at, depth int) error {
if at >= depth { if at >= depth {
return New("ye error") return New("ye error")
@ -22,8 +23,11 @@ func yesErrors(at, depth int) error {
return yesErrors(at+1, depth) return yesErrors(at+1, depth)
} }
// GlobalE is an exported global to store the result of benchmark results,
// preventing the compiler from optimising the benchmark functions away.
var GlobalE error
func BenchmarkErrors(b *testing.B) { func BenchmarkErrors(b *testing.B) {
var toperr error
type run struct { type run struct {
stack int stack int
std bool std bool
@ -53,7 +57,7 @@ func BenchmarkErrors(b *testing.B) {
err = f(0, r.stack) err = f(0, r.stack)
} }
b.StopTimer() b.StopTimer()
toperr = err GlobalE = err
}) })
} }
} }

View File

@ -196,7 +196,6 @@ func TestWithMessage(t *testing.T) {
t.Errorf("WithMessage(%v, %q): got: %q, want %q", tt.err, tt.message, got, tt.want) t.Errorf("WithMessage(%v, %q): got: %q, want %q", tt.err, tt.message, got, tt.want)
} }
} }
} }
// errors.New, etc values are not expected to be compared by value // errors.New, etc values are not expected to be compared by value

View File

@ -79,6 +79,14 @@ func (f Frame) Format(s fmt.State, verb rune) {
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). // StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) { func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb { switch verb {
case 'v': case 'v':

View File

@ -45,3 +45,10 @@ func main() {
Several convenience package level values are provided for cpu, memory, and block (contention) profiling. Several convenience package level values are provided for cpu, memory, and block (contention) profiling.
For more complex options, consult the [documentation](http://godoc.org/github.com/pkg/profile). For more complex options, consult the [documentation](http://godoc.org/github.com/pkg/profile).
contributing
------------
We welcome pull requests, bug fixes and issue reports.
Before proposing a change, please discuss it first by raising an issue.

View File

@ -31,7 +31,7 @@ func ExampleMemProfileRate() {
func ExampleProfilePath() { func ExampleProfilePath() {
// set the location that the profile will be written to // set the location that the profile will be written to
defer profile.Start(profile.ProfilePath(os.Getenv("HOME"))) defer profile.Start(profile.ProfilePath(os.Getenv("HOME"))).Stop()
} }
func ExampleNoShutdownHook() { func ExampleNoShutdownHook() {
@ -41,13 +41,15 @@ func ExampleNoShutdownHook() {
func ExampleStart_withFlags() { func ExampleStart_withFlags() {
// use the flags package to selectively enable profiling. // use the flags package to selectively enable profiling.
mode := flag.String("profile.mode", "", "enable profiling mode, one of [cpu, mem, block]") mode := flag.String("profile.mode", "", "enable profiling mode, one of [cpu, mem, mutex, block]")
flag.Parse() flag.Parse()
switch *mode { switch *mode {
case "cpu": case "cpu":
defer profile.Start(profile.CPUProfile).Stop() defer profile.Start(profile.CPUProfile).Stop()
case "mem": case "mem":
defer profile.Start(profile.MemProfile).Stop() defer profile.Start(profile.MemProfile).Stop()
case "mutex":
defer profile.Start(profile.MutexProfile).Stop()
case "block": case "block":
defer profile.Start(profile.BlockProfile).Stop() defer profile.Start(profile.BlockProfile).Stop()
default: default:

View File

@ -0,0 +1,13 @@
// +build go1.8
package profile
import "runtime"
func enableMutexProfile() {
runtime.SetMutexProfileFraction(1)
}
func disableMutexProfile() {
runtime.SetMutexProfileFraction(0)
}

View File

@ -0,0 +1,9 @@
// +build !go1.8
package profile
// mock mutex support for Go 1.7 and earlier.
func enableMutexProfile() {}
func disableMutexProfile() {}

View File

@ -16,11 +16,13 @@ import (
const ( const (
cpuMode = iota cpuMode = iota
memMode memMode
mutexMode
blockMode blockMode
traceMode traceMode
) )
type profile struct { // Profile represents an active profiling session.
type Profile struct {
// quiet suppresses informational messages during profiling. // quiet suppresses informational messages during profiling.
quiet bool quiet bool
@ -50,14 +52,14 @@ type profile struct {
// Programs with more sophisticated signal handling should set // Programs with more sophisticated signal handling should set
// this to true and ensure the Stop() function returned from Start() // this to true and ensure the Stop() function returned from Start()
// is called during shutdown. // is called during shutdown.
func NoShutdownHook(p *profile) { p.noShutdownHook = true } func NoShutdownHook(p *Profile) { p.noShutdownHook = true }
// Quiet suppresses informational messages during profiling. // Quiet suppresses informational messages during profiling.
func Quiet(p *profile) { p.quiet = true } func Quiet(p *Profile) { p.quiet = true }
// CPUProfile enables cpu profiling. // CPUProfile enables cpu profiling.
// It disables any previous profiling settings. // It disables any previous profiling settings.
func CPUProfile(p *profile) { p.mode = cpuMode } func CPUProfile(p *Profile) { p.mode = cpuMode }
// DefaultMemProfileRate is the default memory profiling rate. // DefaultMemProfileRate is the default memory profiling rate.
// See also http://golang.org/pkg/runtime/#pkg-variables // See also http://golang.org/pkg/runtime/#pkg-variables
@ -65,35 +67,44 @@ const DefaultMemProfileRate = 4096
// MemProfile enables memory profiling. // MemProfile enables memory profiling.
// It disables any previous profiling settings. // It disables any previous profiling settings.
func MemProfile(p *profile) { func MemProfile(p *Profile) {
p.memProfileRate = DefaultMemProfileRate p.memProfileRate = DefaultMemProfileRate
p.mode = memMode p.mode = memMode
} }
// MemProfileRate enables memory profiling at the preferred rate. // MemProfileRate enables memory profiling at the preferred rate.
// It disables any previous profiling settings. // It disables any previous profiling settings.
func MemProfileRate(rate int) func(*profile) { func MemProfileRate(rate int) func(*Profile) {
return func(p *profile) { return func(p *Profile) {
p.memProfileRate = rate p.memProfileRate = rate
p.mode = memMode p.mode = memMode
} }
} }
// MutexProfile enables mutex profiling.
// It disables any previous profiling settings.
//
// Mutex profiling is a no-op before go1.8.
func MutexProfile(p *Profile) { p.mode = mutexMode }
// BlockProfile enables block (contention) profiling. // BlockProfile enables block (contention) profiling.
// It disables any previous profiling settings. // It disables any previous profiling settings.
func BlockProfile(p *profile) { p.mode = blockMode } func BlockProfile(p *Profile) { p.mode = blockMode }
// Trace profile controls if execution tracing will be enabled. It disables any previous profiling settings.
func TraceProfile(p *Profile) { p.mode = traceMode }
// ProfilePath controls the base path where various profiling // ProfilePath controls the base path where various profiling
// files are written. If blank, the base path will be generated // files are written. If blank, the base path will be generated
// by ioutil.TempDir. // by ioutil.TempDir.
func ProfilePath(path string) func(*profile) { func ProfilePath(path string) func(*Profile) {
return func(p *profile) { return func(p *Profile) {
p.path = path p.path = path
} }
} }
// Stop stops the profile and flushes any unwritten data. // Stop stops the profile and flushes any unwritten data.
func (p *profile) Stop() { func (p *Profile) Stop() {
if !atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { if !atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
// someone has already called close // someone has already called close
return return
@ -108,14 +119,14 @@ var started uint32
// Start starts a new profiling session. // Start starts a new profiling session.
// The caller should call the Stop method on the value returned // The caller should call the Stop method on the value returned
// to cleanly stop profiling. // to cleanly stop profiling.
func Start(options ...func(*profile)) interface { func Start(options ...func(*Profile)) interface {
Stop() Stop()
} { } {
if !atomic.CompareAndSwapUint32(&started, 0, 1) { if !atomic.CompareAndSwapUint32(&started, 0, 1) {
log.Fatal("profile: Start() already called") log.Fatal("profile: Start() already called")
} }
var prof profile var prof Profile
for _, option := range options { for _, option := range options {
option(&prof) option(&prof)
} }
@ -168,6 +179,23 @@ func Start(options ...func(*profile)) interface {
logf("profile: memory profiling disabled, %s", fn) logf("profile: memory profiling disabled, %s", fn)
} }
case mutexMode:
fn := filepath.Join(path, "mutex.pprof")
f, err := os.Create(fn)
if err != nil {
log.Fatalf("profile: could not create mutex profile %q: %v", fn, err)
}
enableMutexProfile()
logf("profile: mutex profiling enabled, %s", fn)
prof.closer = func() {
if mp := pprof.Lookup("mutex"); mp != nil {
mp.WriteTo(f, 0)
}
f.Close()
disableMutexProfile()
logf("profile: mutex profiling disabled, %s", fn)
}
case blockMode: case blockMode:
fn := filepath.Join(path, "block.pprof") fn := filepath.Join(path, "block.pprof")
f, err := os.Create(fn) f, err := os.Create(fn)

View File

@ -14,13 +14,20 @@ import (
type checkFn func(t *testing.T, stdout, stderr []byte, err error) type checkFn func(t *testing.T, stdout, stderr []byte, err error)
var profileTests = []struct { func TestProfile(t *testing.T) {
name string f, err := ioutil.TempFile("", "profile_test")
code string if err != nil {
checks []checkFn t.Fatal(err)
}{{ }
name: "default profile (cpu)", defer os.Remove(f.Name())
code: `
var profileTests = []struct {
name string
code string
checks []checkFn
}{{
name: "default profile (cpu)",
code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -29,14 +36,14 @@ func main() {
defer profile.Start().Stop() defer profile.Start().Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: cpu profiling enabled"), Stderr("profile: cpu profiling enabled"),
NoErr, NoErr,
}, },
}, { }, {
name: "memory profile", name: "memory profile",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -45,14 +52,14 @@ func main() {
defer profile.Start(profile.MemProfile).Stop() defer profile.Start(profile.MemProfile).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: memory profiling enabled"), Stderr("profile: memory profiling enabled"),
NoErr, NoErr,
}, },
}, { }, {
name: "memory profile (rate 2048)", name: "memory profile (rate 2048)",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -61,14 +68,14 @@ func main() {
defer profile.Start(profile.MemProfileRate(2048)).Stop() defer profile.Start(profile.MemProfileRate(2048)).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: memory profiling enabled (rate 2048)"), Stderr("profile: memory profiling enabled (rate 2048)"),
NoErr, NoErr,
}, },
}, { }, {
name: "double start", name: "double start",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -78,14 +85,14 @@ func main() {
profile.Start() profile.Start()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("cpu profiling enabled", "profile: Start() already called"), Stderr("cpu profiling enabled", "profile: Start() already called"),
Err, Err,
}, },
}, { }, {
name: "block profile", name: "block profile",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -94,14 +101,30 @@ func main() {
defer profile.Start(profile.BlockProfile).Stop() defer profile.Start(profile.BlockProfile).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: block profiling enabled"), Stderr("profile: block profiling enabled"),
NoErr, NoErr,
}, },
}, { }, {
name: "profile path", name: "mutex profile",
code: ` code: `
package main
import "github.com/pkg/profile"
func main() {
defer profile.Start(profile.MutexProfile).Stop()
}
`,
checks: []checkFn{
NoStdout,
Stderr("profile: mutex profiling enabled"),
NoErr,
},
}, {
name: "profile path",
code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -110,30 +133,30 @@ func main() {
defer profile.Start(profile.ProfilePath(".")).Stop() defer profile.Start(profile.ProfilePath(".")).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: cpu profiling enabled, cpu.pprof"), Stderr("profile: cpu profiling enabled, cpu.pprof"),
NoErr, NoErr,
}, },
}, { }, {
name: "profile path error", name: "profile path error",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
func main() { func main() {
defer profile.Start(profile.ProfilePath("README.md")).Stop() defer profile.Start(profile.ProfilePath("` + f.Name() + `")).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("could not create initial output"), Stderr("could not create initial output"),
Err, Err,
}, },
}, { }, {
name: "multiple profile sessions", name: "multiple profile sessions",
code: ` code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -143,21 +166,26 @@ func main() {
profile.Start(profile.MemProfile).Stop() profile.Start(profile.MemProfile).Stop()
profile.Start(profile.BlockProfile).Stop() profile.Start(profile.BlockProfile).Stop()
profile.Start(profile.CPUProfile).Stop() profile.Start(profile.CPUProfile).Stop()
profile.Start(profile.MutexProfile).Stop()
} }
`, `,
checks: []checkFn{ checks: []checkFn{
NoStdout, NoStdout,
Stderr("profile: cpu profiling enabled", Stderr("profile: cpu profiling enabled",
"profile: cpu profiling disabled", "profile: cpu profiling disabled",
"profile: memory profiling enabled", "profile: memory profiling enabled",
"profile: memory profiling disabled", "profile: memory profiling disabled",
"profile: block profiling enabled", "profile: block profiling enabled",
"profile: block profiling disabled"), "profile: block profiling disabled",
NoErr, "profile: cpu profiling enabled",
}, "profile: cpu profiling disabled",
}, { "profile: mutex profiling enabled",
name: "profile quiet", "profile: mutex profiling disabled"),
code: ` NoErr,
},
}, {
name: "profile quiet",
code: `
package main package main
import "github.com/pkg/profile" import "github.com/pkg/profile"
@ -166,10 +194,8 @@ func main() {
defer profile.Start(profile.Quiet).Stop() defer profile.Start(profile.Quiet).Stop()
} }
`, `,
checks: []checkFn{NoStdout, NoStderr, NoErr}, checks: []checkFn{NoStdout, NoStderr, NoErr},
}} }}
func TestProfile(t *testing.T) {
for _, tt := range profileTests { for _, tt := range profileTests {
t.Log(tt.name) t.Log(tt.name)
stdout, stderr, err := runTest(t, tt.code) stdout, stderr, err := runTest(t, tt.code)

View File

@ -4,8 +4,5 @@ package profile
import "runtime/trace" import "runtime/trace"
// Trace profile controls if execution tracing will be enabled. It disables any previous profiling settings.
func TraceProfile(p *profile) { p.mode = traceMode }
var startTrace = trace.Start var startTrace = trace.Start
var stopTrace = trace.Stop var stopTrace = trace.Stop

View File

@ -1,5 +1,3 @@
// +build go1.7
package profile_test package profile_test
import "github.com/pkg/profile" import "github.com/pkg/profile"

View File

@ -1,2 +1,3 @@
Dave Cheney <dave@cheney.net> Dave Cheney <dave@cheney.net>
Saulius Gurklys <s4uliu5@gmail.com> Saulius Gurklys <s4uliu5@gmail.com>
John Eikenberry <jae@zhar.net>

View File

@ -1,27 +1,44 @@
sftp sftp
---- ----
The `sftp` package provides support for file system operations on remote ssh servers using the SFTP subsystem. The `sftp` package provides support for file system operations on remote ssh
servers using the SFTP subsystem. It also implements an SFTP server for serving
files from the filesystem.
[![UNIX Build Status](https://travis-ci.org/pkg/sftp.svg?branch=master)](https://travis-ci.org/pkg/sftp) [![GoDoc](http://godoc.org/github.com/pkg/sftp?status.svg)](http://godoc.org/github.com/pkg/sftp) [![UNIX Build Status](https://travis-ci.org/pkg/sftp.svg?branch=master)](https://travis-ci.org/pkg/sftp) [![GoDoc](http://godoc.org/github.com/pkg/sftp?status.svg)](http://godoc.org/github.com/pkg/sftp)
usage and examples usage and examples
------------------ ------------------
See [godoc.org/github.com/pkg/sftp](http://godoc.org/github.com/pkg/sftp) for examples and usage. See [godoc.org/github.com/pkg/sftp](http://godoc.org/github.com/pkg/sftp) for
examples and usage.
The basic operation of the package mirrors the facilities of the [os](http://golang.org/pkg/os) package. The basic operation of the package mirrors the facilities of the
[os](http://golang.org/pkg/os) package.
The Walker interface for directory traversal is heavily inspired by Keith Rarick's [fs](http://godoc.org/github.com/kr/fs) package. The Walker interface for directory traversal is heavily inspired by Keith
Rarick's [fs](http://godoc.org/github.com/kr/fs) package.
roadmap roadmap
------- -------
* There is way too much duplication in the Client methods. If there was an unmarshal(interface{}) method this would reduce a heap of the duplication. * There is way too much duplication in the Client methods. If there was an
unmarshal(interface{}) method this would reduce a heap of the duplication.
contributing contributing
------------ ------------
We welcome pull requests, bug fixes and issue reports. We welcome pull requests, bug fixes and issue reports.
Before proposing a large change, first please discuss your change by raising an issue. Before proposing a large change, first please discuss your change by raising an
issue.
For API/code bugs, please include a small, self contained code example to
reproduce the issue. For pull requests, remember test coverage.
We try to handle issues and pull requests with a 0 open philosophy. That means
we will try to address the submission as soon as possible and will work toward
a resolution. If progress can no longer be made (eg. unreproducible bug) or
stops (eg. unresponsive submitter), we will close the bug.
Thanks.

View File

@ -203,7 +203,7 @@ func (c *Client) opendir(path string) (string, error) {
handle, _ := unmarshalString(data) handle, _ := unmarshalString(data)
return handle, nil return handle, nil
case ssh_FXP_STATUS: case ssh_FXP_STATUS:
return "", unmarshalStatus(id, data) return "", normaliseError(unmarshalStatus(id, data))
default: default:
return "", unimplementedPacketErr(typ) return "", unimplementedPacketErr(typ)
} }
@ -284,7 +284,7 @@ func (c *Client) ReadLink(p string) (string, error) {
filename, _ := unmarshalString(data) // ignore dummy attributes filename, _ := unmarshalString(data) // ignore dummy attributes
return filename, nil return filename, nil
case ssh_FXP_STATUS: case ssh_FXP_STATUS:
return "", unmarshalStatus(id, data) return "", normaliseError(unmarshalStatus(id, data))
default: default:
return "", unimplementedPacketErr(typ) return "", unimplementedPacketErr(typ)
} }
@ -439,7 +439,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
attr, _ := unmarshalAttrs(data) attr, _ := unmarshalAttrs(data)
return attr, nil return attr, nil
case ssh_FXP_STATUS: case ssh_FXP_STATUS:
return nil, unmarshalStatus(id, data) return nil, normaliseError(unmarshalStatus(id, data))
default: default:
return nil, unimplementedPacketErr(typ) return nil, unimplementedPacketErr(typ)
} }
@ -495,7 +495,7 @@ func (c *Client) Remove(path string) error {
// some servers, *cough* osx *cough*, return EPERM, not ENODIR. // some servers, *cough* osx *cough*, return EPERM, not ENODIR.
// serv-u returns ssh_FX_FILE_IS_A_DIRECTORY // serv-u returns ssh_FX_FILE_IS_A_DIRECTORY
case ssh_FX_PERMISSION_DENIED, ssh_FX_FAILURE, ssh_FX_FILE_IS_A_DIRECTORY: case ssh_FX_PERMISSION_DENIED, ssh_FX_FAILURE, ssh_FX_FILE_IS_A_DIRECTORY:
return c.removeDirectory(path) return c.RemoveDirectory(path)
} }
} }
return err return err
@ -518,7 +518,8 @@ func (c *Client) removeFile(path string) error {
} }
} }
func (c *Client) removeDirectory(path string) error { // RemoveDirectory removes a directory path.
func (c *Client) RemoveDirectory(path string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(sshFxpRmdirPacket{ typ, data, err := c.sendPacket(sshFxpRmdirPacket{
ID: id, ID: id,
@ -640,9 +641,10 @@ func (f *File) Name() string {
const maxConcurrentRequests = 64 const maxConcurrentRequests = 64
// Read reads up to len(b) bytes from the File. It returns the number of // Read reads up to len(b) bytes from the File. It returns the number of bytes
// bytes read and an error, if any. EOF is signaled by a zero count with // read and an error, if any. Read follows io.Reader semantics, so when Read
// err set to io.EOF. // encounters an error or EOF condition after successfully reading n > 0 bytes,
// it returns the number of bytes read.
func (f *File) Read(b []byte) (int, error) { func (f *File) Read(b []byte) (int, error) {
// Split the read into multiple maxPacket sized concurrent reads // Split the read into multiple maxPacket sized concurrent reads
// bounded by maxConcurrentRequests. This allows reads with a suitably // bounded by maxConcurrentRequests. This allows reads with a suitably
@ -651,7 +653,9 @@ func (f *File) Read(b []byte) (int, error) {
inFlight := 0 inFlight := 0
desiredInFlight := 1 desiredInFlight := 1
offset := f.offset offset := f.offset
ch := make(chan result, 1) // maxConcurrentRequests buffer to deal with broadcastErr() floods
// also must have a buffer of max value of (desiredInFlight - inFlight)
ch := make(chan result, maxConcurrentRequests)
type inflightRead struct { type inflightRead struct {
b []byte b []byte
offset uint64 offset uint64
@ -688,43 +692,39 @@ func (f *File) Read(b []byte) (int, error) {
if inFlight == 0 { if inFlight == 0 {
break break
} }
select { res := <-ch
case res := <-ch: inFlight--
inFlight-- if res.err != nil {
if res.err != nil { firstErr = offsetErr{offset: 0, err: res.err}
firstErr = offsetErr{offset: 0, err: res.err} continue
break }
} reqID, data := unmarshalUint32(res.data)
reqID, data := unmarshalUint32(res.data) req, ok := reqs[reqID]
req, ok := reqs[reqID] if !ok {
if !ok { firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} continue
break }
} delete(reqs, reqID)
delete(reqs, reqID) switch res.typ {
switch res.typ { case ssh_FXP_STATUS:
case ssh_FXP_STATUS: if firstErr.err == nil || req.offset < firstErr.offset {
if firstErr.err == nil || req.offset < firstErr.offset { firstErr = offsetErr{
firstErr = offsetErr{ offset: req.offset,
offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data)),
err: normaliseError(unmarshalStatus(reqID, res.data)),
}
break
} }
case ssh_FXP_DATA:
l, data := unmarshalUint32(data)
n := copy(req.b, data[:l])
read += n
if n < len(req.b) {
sendReq(req.b[l:], req.offset+uint64(l))
}
if desiredInFlight < maxConcurrentRequests {
desiredInFlight++
}
default:
firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)}
break
} }
case ssh_FXP_DATA:
l, data := unmarshalUint32(data)
n := copy(req.b, data[:l])
read += n
if n < len(req.b) {
sendReq(req.b[l:], req.offset+uint64(l))
}
if desiredInFlight < maxConcurrentRequests {
desiredInFlight++
}
default:
firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)}
} }
} }
// If the error is anything other than EOF, then there // If the error is anything other than EOF, then there
@ -750,7 +750,8 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
offset := f.offset offset := f.offset
writeOffset := offset writeOffset := offset
fileSize := uint64(fi.Size()) fileSize := uint64(fi.Size())
ch := make(chan result, 1) // see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
type inflightRead struct { type inflightRead struct {
b []byte b []byte
offset uint64 offset uint64
@ -777,83 +778,92 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
var copied int64 var copied int64
for firstErr.err == nil || inFlight > 0 { for firstErr.err == nil || inFlight > 0 {
for inFlight < desiredInFlight && firstErr.err == nil { if firstErr.err == nil {
b := make([]byte, f.c.maxPacket) for inFlight+len(pendingWrites) < desiredInFlight {
sendReq(b, offset) b := make([]byte, f.c.maxPacket)
offset += uint64(f.c.maxPacket) sendReq(b, offset)
if offset > fileSize { offset += uint64(f.c.maxPacket)
desiredInFlight = 1 if offset > fileSize {
desiredInFlight = 1
}
} }
} }
if inFlight == 0 { if inFlight == 0 {
if firstErr.err == nil && len(pendingWrites) > 0 {
return copied, errors.New("internal inconsistency")
}
break break
} }
select { res := <-ch
case res := <-ch: inFlight--
inFlight-- if res.err != nil {
if res.err != nil { firstErr = offsetErr{offset: 0, err: res.err}
firstErr = offsetErr{offset: 0, err: res.err} continue
break }
reqID, data := unmarshalUint32(res.data)
req, ok := reqs[reqID]
if !ok {
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
continue
}
delete(reqs, reqID)
switch res.typ {
case ssh_FXP_STATUS:
if firstErr.err == nil || req.offset < firstErr.offset {
firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))}
} }
reqID, data := unmarshalUint32(res.data) case ssh_FXP_DATA:
req, ok := reqs[reqID] l, data := unmarshalUint32(data)
if !ok { if req.offset == writeOffset {
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} nbytes, err := w.Write(data)
break copied += int64(nbytes)
} if err != nil {
delete(reqs, reqID) // We will never receive another DATA with offset==writeOffset, so
switch res.typ { // the loop will drain inFlight and then exit.
case ssh_FXP_STATUS: firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err}
if firstErr.err == nil || req.offset < firstErr.offset {
firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))}
break break
} }
case ssh_FXP_DATA: if nbytes < int(l) {
l, data := unmarshalUint32(data) firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite}
if req.offset == writeOffset { break
nbytes, err := w.Write(data) }
copied += int64(nbytes) switch {
case offset > fileSize:
desiredInFlight = 1
case desiredInFlight < maxConcurrentRequests:
desiredInFlight++
}
writeOffset += uint64(nbytes)
for {
pendingData, ok := pendingWrites[writeOffset]
if !ok {
break
}
// Give go a chance to free the memory.
delete(pendingWrites, writeOffset)
nbytes, err := w.Write(pendingData)
// Do not move writeOffset on error so subsequent iterations won't trigger
// any writes.
if err != nil { if err != nil {
firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err} firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err}
break break
} }
if nbytes < int(l) { if nbytes < len(pendingData) {
firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite} firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite}
break break
} }
switch {
case offset > fileSize:
desiredInFlight = 1
case desiredInFlight < maxConcurrentRequests:
desiredInFlight++
}
writeOffset += uint64(nbytes) writeOffset += uint64(nbytes)
for pendingData, ok := pendingWrites[writeOffset]; ok; pendingData, ok = pendingWrites[writeOffset] {
nbytes, err := w.Write(pendingData)
if err != nil {
firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err}
break
}
if nbytes < len(pendingData) {
firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite}
break
}
writeOffset += uint64(nbytes)
inFlight--
}
} else {
// Don't write the data yet because
// this response came in out of order
// and we need to wait for responses
// for earlier segments of the file.
inFlight++ // Pending writes should still be considered inFlight.
pendingWrites[req.offset] = data
} }
default: } else {
firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} // Don't write the data yet because
break // this response came in out of order
// and we need to wait for responses
// for earlier segments of the file.
pendingWrites[req.offset] = data
} }
default:
firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)}
} }
} }
if firstErr.err != io.EOF { if firstErr.err != io.EOF {
@ -883,7 +893,8 @@ func (f *File) Write(b []byte) (int, error) {
inFlight := 0 inFlight := 0
desiredInFlight := 1 desiredInFlight := 1
offset := f.offset offset := f.offset
ch := make(chan result, 1) // see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
var firstErr error var firstErr error
written := len(b) written := len(b)
for len(b) > 0 || inFlight > 0 { for len(b) > 0 || inFlight > 0 {
@ -905,28 +916,25 @@ func (f *File) Write(b []byte) (int, error) {
if inFlight == 0 { if inFlight == 0 {
break break
} }
select { res := <-ch
case res := <-ch: inFlight--
inFlight-- if res.err != nil {
if res.err != nil { firstErr = res.err
firstErr = res.err continue
}
switch res.typ {
case ssh_FXP_STATUS:
id, _ := unmarshalUint32(res.data)
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break break
} }
switch res.typ { if desiredInFlight < maxConcurrentRequests {
case ssh_FXP_STATUS: desiredInFlight++
id, _ := unmarshalUint32(res.data)
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break
}
if desiredInFlight < maxConcurrentRequests {
desiredInFlight++
}
default:
firstErr = unimplementedPacketErr(res.typ)
break
} }
default:
firstErr = unimplementedPacketErr(res.typ)
} }
} }
// If error is non-nil, then there may be gaps in the data written to // If error is non-nil, then there may be gaps in the data written to
@ -946,7 +954,8 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
inFlight := 0 inFlight := 0
desiredInFlight := 1 desiredInFlight := 1
offset := f.offset offset := f.offset
ch := make(chan result, 1) // see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
var firstErr error var firstErr error
read := int64(0) read := int64(0)
b := make([]byte, f.c.maxPacket) b := make([]byte, f.c.maxPacket)
@ -971,28 +980,25 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
if inFlight == 0 { if inFlight == 0 {
break break
} }
select { res := <-ch
case res := <-ch: inFlight--
inFlight-- if res.err != nil {
if res.err != nil { firstErr = res.err
firstErr = res.err continue
}
switch res.typ {
case ssh_FXP_STATUS:
id, _ := unmarshalUint32(res.data)
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break break
} }
switch res.typ { if desiredInFlight < maxConcurrentRequests {
case ssh_FXP_STATUS: desiredInFlight++
id, _ := unmarshalUint32(res.data)
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break
}
if desiredInFlight < maxConcurrentRequests {
desiredInFlight++
}
default:
firstErr = unimplementedPacketErr(res.typ)
break
} }
default:
firstErr = unimplementedPacketErr(res.typ)
} }
} }
if firstErr == io.EOF { if firstErr == io.EOF {
@ -1080,10 +1086,7 @@ func unmarshalStatus(id uint32, data []byte) error {
return &unexpectedIDErr{id, sid} return &unexpectedIDErr{id, sid}
} }
code, data := unmarshalUint32(data) code, data := unmarshalUint32(data)
msg, data, err := unmarshalStringSafe(data) msg, data, _ := unmarshalStringSafe(data)
if err != nil {
return err
}
lang, _, _ := unmarshalStringSafe(data) lang, _, _ := unmarshalStringSafe(data)
return &StatusError{ return &StatusError{
Code: code, Code: code,

View File

@ -29,14 +29,14 @@ func TestClientStatVFS(t *testing.T) {
// check some stats // check some stats
if vfs.Frsize != uint64(s.Frsize) { if vfs.Frsize != uint64(s.Frsize) {
t.Fatal("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize) t.Fatalf("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize)
} }
if vfs.Bsize != uint64(s.Bsize) { if vfs.Bsize != uint64(s.Bsize) {
t.Fatal("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize) t.Fatalf("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize)
} }
if vfs.Namemax != uint64(s.Namelen) { if vfs.Namemax != uint64(s.Namelen) {
t.Fatal("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax) t.Fatalf("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax)
} }
} }

View File

@ -4,7 +4,10 @@ package sftp
// enable with -integration // enable with -integration
import ( import (
"bytes"
"crypto/sha1" "crypto/sha1"
"encoding"
"errors"
"flag" "flag"
"io" "io"
"io/ioutil" "io/ioutil"
@ -22,6 +25,8 @@ import (
"testing/quick" "testing/quick"
"time" "time"
"sort"
"github.com/kr/fs" "github.com/kr/fs"
) )
@ -87,7 +92,7 @@ func (w delayedWriter) Close() error {
// netPipe provides a pair of io.ReadWriteClosers connected to each other. // netPipe provides a pair of io.ReadWriteClosers connected to each other.
// The functions is identical to os.Pipe with the exception that netPipe // The functions is identical to os.Pipe with the exception that netPipe
// provides the Read/Close guarentees that os.File derrived pipes do not. // provides the Read/Close guarantees that os.File derrived pipes do not.
func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) {
type result struct { type result struct {
net.Conn net.Conn
@ -1159,6 +1164,197 @@ func TestClientWrite(t *testing.T) {
} }
} }
// ReadFrom is basically Write with io.Reader as the arg
func TestClientReadFrom(t *testing.T) {
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
d, err := ioutil.TempDir("", "sftptest")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
f := path.Join(d, "writeTest")
w, err := sftp.Create(f)
if err != nil {
t.Fatal(err)
}
defer w.Close()
for _, tt := range clientWriteTests {
got, err := w.ReadFrom(bytes.NewReader(make([]byte, tt.n)))
if err != nil {
t.Fatal(err)
}
if got != int64(tt.n) {
t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got)
}
fi, err := os.Stat(f)
if err != nil {
t.Fatal(err)
}
if total := fi.Size(); total != tt.total {
t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total)
}
}
}
// Issue #145 in github
// Deadlock in ReadFrom when network drops after 1 good packet.
// Deadlock would occur anytime desiredInFlight-inFlight==2 and 2 errors
// occured in a row. The channel to report the errors only had a buffer
// of 1 and 2 would be sent.
var fakeNetErr = errors.New("Fake network issue")
func TestClientReadFromDeadlock(t *testing.T) {
clientWriteDeadlock(t, 1, func(f *File) {
b := make([]byte, 32768*4)
content := bytes.NewReader(b)
n, err := f.ReadFrom(content)
if n != 0 {
t.Fatal("Write should return 0", n)
}
if err != fakeNetErr {
t.Fatal("Didn't recieve correct error", err)
}
})
}
// Write has exact same problem
func TestClientWriteDeadlock(t *testing.T) {
clientWriteDeadlock(t, 1, func(f *File) {
b := make([]byte, 32768*4)
n, err := f.Write(b)
if n != 0 {
t.Fatal("Write should return 0", n)
}
if err != fakeNetErr {
t.Fatal("Didn't recieve correct error", err)
}
})
}
// shared body for both previous tests
func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
if !*testServerImpl {
t.Skipf("skipping without -testserver")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
d, err := ioutil.TempDir("", "sftptest")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
f := path.Join(d, "writeTest")
w, err := sftp.Create(f)
if err != nil {
t.Fatal(err)
}
defer w.Close()
// Override sendPacket with failing version
// Replicates network error/drop part way through (after 1 good packet)
count := 0
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error {
count++
if count > N {
return fakeNetErr
}
return sendPacket(w, m)
}
sftp.clientConn.conn.sendPacketTest = sendPacketTest
defer func() {
sftp.clientConn.conn.sendPacketTest = nil
}()
// this locked (before the fix)
badfunc(w)
}
// Read/WriteTo has this issue as well
func TestClientReadDeadlock(t *testing.T) {
clientReadDeadlock(t, 1, func(f *File) {
b := make([]byte, 32768*4)
n, err := f.Read(b)
if n != 0 {
t.Fatal("Write should return 0", n)
}
if err != fakeNetErr {
t.Fatal("Didn't recieve correct error", err)
}
})
}
func TestClientWriteToDeadlock(t *testing.T) {
clientReadDeadlock(t, 2, func(f *File) {
b := make([]byte, 32768*4)
buf := bytes.NewBuffer(b)
n, err := f.WriteTo(buf)
if n != 32768 {
t.Fatal("Write should return 0", n)
}
if err != fakeNetErr {
t.Fatal("Didn't recieve correct error", err)
}
})
}
func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) {
if !*testServerImpl {
t.Skipf("skipping without -testserver")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
d, err := ioutil.TempDir("", "sftptest")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
f := path.Join(d, "writeTest")
w, err := sftp.Create(f)
if err != nil {
t.Fatal(err)
}
// write the data for the read tests
b := make([]byte, 32768*4)
w.Write(b)
defer w.Close()
// open new copy of file for read tests
r, err := sftp.Open(f)
if err != nil {
t.Fatal(err)
}
defer r.Close()
// Override sendPacket with failing version
// Replicates network error/drop part way through (after 1 good packet)
count := 0
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error {
count++
if count > N {
return fakeNetErr
}
return sendPacket(w, m)
}
sftp.clientConn.conn.sendPacketTest = sendPacketTest
defer func() {
sftp.clientConn.conn.sendPacketTest = nil
}()
// this locked (before the fix)
badfunc(r)
}
// taken from github.com/kr/fs/walk_test.go // taken from github.com/kr/fs/walk_test.go
type Node struct { type Node struct {
@ -1330,6 +1526,169 @@ func TestClientWalk(t *testing.T) {
} }
} }
type MatchTest struct {
pattern, s string
match bool
err error
}
var matchTests = []MatchTest{
{"abc", "abc", true, nil},
{"*", "abc", true, nil},
{"*c", "abc", true, nil},
{"a*", "a", true, nil},
{"a*", "abc", true, nil},
{"a*", "ab/c", false, nil},
{"a*/b", "abc/b", true, nil},
{"a*/b", "a/c/b", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil},
{"a*b?c*x", "abxbbxdbxebxczzx", true, nil},
{"a*b?c*x", "abxbbxdbxebxczzy", false, nil},
{"ab[c]", "abc", true, nil},
{"ab[b-d]", "abc", true, nil},
{"ab[e-g]", "abc", false, nil},
{"ab[^c]", "abc", false, nil},
{"ab[^b-d]", "abc", false, nil},
{"ab[^e-g]", "abc", true, nil},
{"a\\*b", "a*b", true, nil},
{"a\\*b", "ab", false, nil},
{"a?b", "a☺b", true, nil},
{"a[^a]b", "a☺b", true, nil},
{"a???b", "a☺b", false, nil},
{"a[^a][^a][^a]b", "a☺b", false, nil},
{"[a-ζ]*", "α", true, nil},
{"*[a-ζ]", "A", false, nil},
{"a?b", "a/b", false, nil},
{"a*b", "a/b", false, nil},
{"[\\]a]", "]", true, nil},
{"[\\-]", "-", true, nil},
{"[x\\-]", "x", true, nil},
{"[x\\-]", "-", true, nil},
{"[x\\-]", "z", false, nil},
{"[\\-x]", "x", true, nil},
{"[\\-x]", "-", true, nil},
{"[\\-x]", "a", false, nil},
{"[]a]", "]", false, ErrBadPattern},
{"[-]", "-", false, ErrBadPattern},
{"[x-]", "x", false, ErrBadPattern},
{"[x-]", "-", false, ErrBadPattern},
{"[x-]", "z", false, ErrBadPattern},
{"[-x]", "x", false, ErrBadPattern},
{"[-x]", "-", false, ErrBadPattern},
{"[-x]", "a", false, ErrBadPattern},
{"\\", "a", false, ErrBadPattern},
{"[a-b-c]", "a", false, ErrBadPattern},
{"[", "a", false, ErrBadPattern},
{"[^", "a", false, ErrBadPattern},
{"[^bc", "a", false, ErrBadPattern},
{"a[", "a", false, nil},
{"a[", "ab", false, ErrBadPattern},
{"*x", "xxx", true, nil},
}
func errp(e error) string {
if e == nil {
return "<nil>"
}
return e.Error()
}
// contains returns true if vector contains the string s.
func contains(vector []string, s string) bool {
for _, elem := range vector {
if elem == s {
return true
}
}
return false
}
var globTests = []struct {
pattern, result string
}{
{"match.go", "./match.go"},
{"mat?h.go", "./match.go"},
{"ma*ch.go", "./match.go"},
{"../*/match.go", "../sftp/match.go"},
}
type globTest struct {
pattern string
matches []string
}
func (test *globTest) buildWant(root string) []string {
var want []string
for _, m := range test.matches {
want = append(want, root+filepath.FromSlash(m))
}
sort.Strings(want)
return want
}
func TestMatch(t *testing.T) {
for _, tt := range matchTests {
pattern := tt.pattern
s := tt.s
ok, err := Match(pattern, s)
if ok != tt.match || err != tt.err {
t.Errorf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err))
}
}
}
func TestGlob(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
for _, tt := range globTests {
pattern := tt.pattern
result := tt.result
matches, err := sftp.Glob(pattern)
if err != nil {
t.Errorf("Glob error for %q: %s", pattern, err)
continue
}
if !contains(matches, result) {
t.Errorf("Glob(%#q) = %#v want %v", pattern, matches, result)
}
}
for _, pattern := range []string{"no_match", "../*/no_match"} {
matches, err := sftp.Glob(pattern)
if err != nil {
t.Errorf("Glob error for %q: %s", pattern, err)
continue
}
if len(matches) != 0 {
t.Errorf("Glob(%#q) = %#v want []", pattern, matches)
}
}
}
func TestGlobError(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
_, err := sftp.Glob("[7]")
if err != nil {
t.Error("expected error for bad pattern; got none")
}
}
func TestGlobUNC(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
// Just make sure this runs without crashing for now.
// See issue 15879.
sftp.Glob(`\\?\C:\*`)
}
// sftp/issue/42, abrupt server hangup would result in client hangs. // sftp/issue/42, abrupt server hangup would result in client hangs.
func TestServerRoughDisconnect(t *testing.T) { func TestServerRoughDisconnect(t *testing.T) {
if *testServerImpl { if *testServerImpl {
@ -1352,6 +1711,35 @@ func TestServerRoughDisconnect(t *testing.T) {
io.Copy(ioutil.Discard, f) io.Copy(ioutil.Discard, f)
} }
// sftp/issue/181, abrupt server hangup would result in client hangs.
// due to broadcastErr filling up the request channel
// this reproduces it about 50% of the time
func TestServerRoughDisconnect2(t *testing.T) {
if *testServerImpl {
t.Skipf("skipping with -testserver")
}
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
f, err := sftp.Open("/dev/zero")
if err != nil {
t.Fatal(err)
}
defer f.Close()
b := make([]byte, 32768*100)
go func() {
time.Sleep(1 * time.Millisecond)
cmd.Process.Kill()
}()
for {
_, err = f.Read(b)
if err != nil {
break
}
}
}
// sftp/issue/26 writing to a read only file caused client to loop. // sftp/issue/26 writing to a read only file caused client to loop.
func TestClientWriteToROFile(t *testing.T) { func TestClientWriteToROFile(t *testing.T) {
sftp, cmd := testClient(t, READWRITE, NO_DELAY) sftp, cmd := testClient(t, READWRITE, NO_DELAY)
@ -1375,7 +1763,7 @@ func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) {
// open sftp client // open sftp client
sftp, cmd := testClient(b, READONLY, delay) sftp, cmd := testClient(b, READONLY, delay)
defer cmd.Wait() defer cmd.Wait()
defer sftp.Close() // defer sftp.Close()
buf := make([]byte, bufsize) buf := make([]byte, bufsize)
@ -1453,7 +1841,7 @@ func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) {
// open sftp client // open sftp client
sftp, cmd := testClient(b, false, delay) sftp, cmd := testClient(b, false, delay)
defer cmd.Wait() defer cmd.Wait()
defer sftp.Close() // defer sftp.Close()
data := make([]byte, size) data := make([]byte, size)
@ -1543,6 +1931,88 @@ func BenchmarkWrite4MiBDelay150Msec(b *testing.B) {
benchmarkWrite(b, 4*1024*1024, 150*time.Millisecond) benchmarkWrite(b, 4*1024*1024, 150*time.Millisecond)
} }
func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) {
size := 10*1024*1024 + 123 // ~10MiB
// open sftp client
sftp, cmd := testClient(b, false, delay)
defer cmd.Wait()
// defer sftp.Close()
data := make([]byte, size)
b.ResetTimer()
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
f, err := ioutil.TempFile("", "sftptest")
if err != nil {
b.Fatal(err)
}
defer os.Remove(f.Name())
f2, err := sftp.Create(f.Name())
if err != nil {
b.Fatal(err)
}
defer f2.Close()
f2.ReadFrom(bytes.NewReader(data))
f2.Close()
fi, err := os.Stat(f.Name())
if err != nil {
b.Fatal(err)
}
if fi.Size() != int64(size) {
b.Fatalf("wrong file size: want %d, got %d", size, fi.Size())
}
os.Remove(f.Name())
}
}
func BenchmarkReadFrom1k(b *testing.B) {
benchmarkReadFrom(b, 1*1024, NO_DELAY)
}
func BenchmarkReadFrom16k(b *testing.B) {
benchmarkReadFrom(b, 16*1024, NO_DELAY)
}
func BenchmarkReadFrom32k(b *testing.B) {
benchmarkReadFrom(b, 32*1024, NO_DELAY)
}
func BenchmarkReadFrom128k(b *testing.B) {
benchmarkReadFrom(b, 128*1024, NO_DELAY)
}
func BenchmarkReadFrom512k(b *testing.B) {
benchmarkReadFrom(b, 512*1024, NO_DELAY)
}
func BenchmarkReadFrom1MiB(b *testing.B) {
benchmarkReadFrom(b, 1024*1024, NO_DELAY)
}
func BenchmarkReadFrom4MiB(b *testing.B) {
benchmarkReadFrom(b, 4*1024*1024, NO_DELAY)
}
func BenchmarkReadFrom4MiBDelay10Msec(b *testing.B) {
benchmarkReadFrom(b, 4*1024*1024, 10*time.Millisecond)
}
func BenchmarkReadFrom4MiBDelay50Msec(b *testing.B) {
benchmarkReadFrom(b, 4*1024*1024, 50*time.Millisecond)
}
func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) {
benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond)
}
func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) { func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) {
// Create a temp file and fill it with zero's. // Create a temp file and fill it with zero's.
src, err := ioutil.TempFile("", "sftptest") src, err := ioutil.TempFile("", "sftptest")
@ -1568,7 +2038,7 @@ func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) {
sftp, cmd := testClient(b, READONLY, delay) sftp, cmd := testClient(b, READONLY, delay)
defer cmd.Wait() defer cmd.Wait()
defer sftp.Close() // defer sftp.Close()
b.ResetTimer() b.ResetTimer()
b.SetBytes(fileSize) b.SetBytes(fileSize)
@ -1641,7 +2111,7 @@ func benchmarkCopyUp(b *testing.B, fileSize int64, delay time.Duration) {
sftp, cmd := testClient(b, false, delay) sftp, cmd := testClient(b, false, delay)
defer cmd.Wait() defer cmd.Wait()
defer sftp.Close() // defer sftp.Close()
b.ResetTimer() b.ResetTimer()
b.SetBytes(fileSize) b.SetBytes(fileSize)

View File

@ -116,7 +116,9 @@ func TestUnmarshalStatus(t *testing.T) {
desc: "missing error message and language tag", desc: "missing error message and language tag",
reqID: 1, reqID: 1,
status: idCode, status: idCode,
want: errShortPacket, want: &StatusError{
Code: ssh_FX_FAILURE,
},
}, },
{ {
desc: "missing language tag", desc: "missing language tag",

View File

@ -14,6 +14,8 @@ type conn struct {
io.Reader io.Reader
io.WriteCloser io.WriteCloser
sync.Mutex // used to serialise writes to sendPacket sync.Mutex // used to serialise writes to sendPacket
// sendPacketTest is needed to replicate packet issues in testing
sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error
} }
func (c *conn) recvPacket() (uint8, []byte, error) { func (c *conn) recvPacket() (uint8, []byte, error) {
@ -23,6 +25,9 @@ func (c *conn) recvPacket() (uint8, []byte, error) {
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if c.sendPacketTest != nil {
return c.sendPacketTest(c, m)
}
return sendPacket(c, m) return sendPacket(c, m)
} }
@ -50,7 +55,11 @@ func (c *clientConn) loop() {
// recv continuously reads from the server and forwards responses to the // recv continuously reads from the server and forwards responses to the
// appropriate channel. // appropriate channel.
func (c *clientConn) recv() error { func (c *clientConn) recv() error {
defer c.conn.Close() defer func() {
c.conn.Lock()
c.conn.Close()
c.conn.Unlock()
}()
for { for {
typ, data, err := c.recvPacket() typ, data, err := c.recvPacket()
if err != nil { if err != nil {
@ -93,11 +102,13 @@ func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) {
func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
c.Lock() c.Lock()
c.inflight[p.id()] = ch c.inflight[p.id()] = ch
c.Unlock()
if err := c.conn.sendPacket(p); err != nil { if err := c.conn.sendPacket(p); err != nil {
c.Lock()
delete(c.inflight, p.id()) delete(c.inflight, p.id())
c.Unlock()
ch <- result{err: err} ch <- result{err: err}
} }
c.Unlock()
} }
// broadcastErr sends an error to all goroutines waiting for a response. // broadcastErr sends an error to all goroutines waiting for a response.
@ -117,6 +128,6 @@ type serverConn struct {
conn conn
} }
func (s *serverConn) sendError(p id, err error) error { func (s *serverConn) sendError(p ider, err error) error {
return s.sendPacket(statusFromError(p, err)) return s.sendPacket(statusFromError(p, err))
} }

View File

@ -5,12 +5,16 @@ import (
"log" "log"
"os" "os"
"os/exec" "os/exec"
"path"
"strings"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
func Example(conn *ssh.Client) { func Example() {
var conn *ssh.Client
// open an SFTP session over an existing ssh connection. // open an SFTP session over an existing ssh connection.
sftp, err := sftp.NewClient(conn) sftp, err := sftp.NewClient(conn)
if err != nil { if err != nil {
@ -88,3 +92,44 @@ func ExampleNewClientPipe() {
// close the connection // close the connection
client.Close() client.Close()
} }
func ExampleClient_Mkdir_parents() {
// Example of mimicing 'mkdir --parents'; I.E. recursively create
// directoryies and don't error if any directories already exists.
var conn *ssh.Client
client, err := sftp.NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer client.Close()
ssh_fx_failure := uint32(4)
mkdirParents := func(client *sftp.Client, dir string) (err error) {
var parents string
for _, name := range strings.Split(dir, "/") {
parents = path.Join(parents, name)
err = client.Mkdir(parents)
if status, ok := err.(*sftp.StatusError); ok {
if status.Code == ssh_fx_failure {
var fi os.FileInfo
fi, err = client.Stat(parents)
if err == nil {
if !fi.IsDir() {
return fmt.Errorf("File exists: %s", parents)
}
}
}
}
if err != nil {
break
}
}
return err
}
err = mkdirParents(client, "/tmp/foo/bar")
if err != nil {
log.Fatal(err)
}
}

View File

@ -42,6 +42,7 @@ func main() {
config := ssh.ClientConfig{ config := ssh.ClientConfig{
User: *USER, User: *USER,
Auth: auths, Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
addr := fmt.Sprintf("%s:%d", *HOST, *PORT) addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config) conn, err := ssh.Dial("tcp", addr, &config)

View File

@ -42,6 +42,7 @@ func main() {
config := ssh.ClientConfig{ config := ssh.ClientConfig{
User: *USER, User: *USER,
Auth: auths, Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
addr := fmt.Sprintf("%s:%d", *HOST, *PORT) addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config) conn, err := ssh.Dial("tcp", addr, &config)

View File

@ -0,0 +1,131 @@
// An example SFTP server implementation using the golang SSH package.
// Serves the whole filesystem visible to the user, and has a hard-coded username and password,
// so not for real use!
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Based on example server code from golang.org/x/crypto/ssh and server_standalone
func main() {
var (
readOnly bool
debugStderr bool
)
flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.Parse()
debugStream := ioutil.Discard
if debugStderr {
debugStream = os.Stderr
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
fmt.Fprintf(debugStream, "Login: %s\n", c.User())
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
log.Fatal("Failed to load private key", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Fatal("Failed to parse private key", err)
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
log.Fatal("failed to listen for connection", err)
}
fmt.Printf("Listening on %v\n", listener.Addr())
nConn, err := listener.Accept()
if err != nil {
log.Fatal("failed to accept incoming connection", err)
}
// Before use, a handshake must be performed on the incoming net.Conn.
sconn, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
log.Fatal("failed to handshake", err)
}
log.Println("login detected:", sconn.User())
fmt.Fprintf(debugStream, "SSH server established\n")
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of an SFTP session, this is "subsystem"
// with a payload string of "<length=4>sftp"
fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Fatal("could not accept channel.", err)
}
fmt.Fprintf(debugStream, "Channel accepted\n")
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "subsystem" request.
go func(in <-chan *ssh.Request) {
for req := range in {
fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
ok := false
switch req.Type {
case "subsystem":
fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
if string(req.Payload[4:]) == "sftp" {
ok = true
}
}
fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
req.Reply(ok, nil)
}
}(requests)
root := sftp.InMemHandler()
server := sftp.NewRequestServer(channel, root)
if err := server.Serve(); err == io.EOF {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
log.Fatal("sftp server completed with error:", err)
}
}
}

View File

@ -6,6 +6,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
@ -136,7 +137,10 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if err := server.Serve(); err != nil { if err := server.Serve(); err == io.EOF {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
log.Fatal("sftp server completed with error:", err) log.Fatal("sftp server completed with error:", err)
} }
} }

View File

@ -43,6 +43,7 @@ func main() {
config := ssh.ClientConfig{ config := ssh.ClientConfig{
User: *USER, User: *USER,
Auth: auths, Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
addr := fmt.Sprintf("%s:%d", *HOST, *PORT) addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config) conn, err := ssh.Dial("tcp", addr, &config)

View File

@ -43,6 +43,7 @@ func main() {
config := ssh.ClientConfig{ config := ssh.ClientConfig{
User: *USER, User: *USER,
Auth: auths, Auth: auths,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
addr := fmt.Sprintf("%s:%d", *HOST, *PORT) addr := fmt.Sprintf("%s:%d", *HOST, *PORT)
conn, err := ssh.Dial("tcp", addr, &config) conn, err := ssh.Dial("tcp", addr, &config)

345
vendor/src/github.com/pkg/sftp/match.go vendored Normal file
View File

@ -0,0 +1,345 @@
package sftp
import (
"errors"
"strings"
"unicode/utf8"
)
// ErrBadPattern indicates a globbing pattern was malformed.
var ErrBadPattern = errors.New("syntax error in pattern")
// Unix separator
const separator = "/"
// Match reports whether name matches the shell file name pattern.
// The pattern syntax is:
//
// pattern:
// { term }
// term:
// '*' matches any sequence of non-Separator characters
// '?' matches any single non-Separator character
// '[' [ '^' ] { character-range } ']'
// character class (must be non-empty)
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
//
//
func Match(pattern, name string) (matched bool, err error) {
Pattern:
for len(pattern) > 0 {
var star bool
var chunk string
star, chunk, pattern = scanChunk(pattern)
if star && chunk == "" {
// Trailing * matches rest of string unless it has a /.
return !strings.Contains(name, separator), nil
}
// Look for match at current position.
t, ok, err := matchChunk(chunk, name)
// if we're the last chunk, make sure we've exhausted the name
// otherwise we'll give a false result even if we could still match
// using the star
if ok && (len(t) == 0 || len(pattern) > 0) {
name = t
continue
}
if err != nil {
return false, err
}
if star {
// Look for match skipping i+1 bytes.
// Cannot skip /.
for i := 0; i < len(name) && !isPathSeparator(name[i]); i++ {
t, ok, err := matchChunk(chunk, name[i+1:])
if ok {
// if we're the last chunk, make sure we exhausted the name
if len(pattern) == 0 && len(t) > 0 {
continue
}
name = t
continue Pattern
}
if err != nil {
return false, err
}
}
}
return false, nil
}
return len(name) == 0, nil
}
// detect if byte(char) is path separator
func isPathSeparator(c byte) bool {
return string(c) == "/"
}
// scanChunk gets the next segment of pattern, which is a non-star string
// possibly preceded by a star.
func scanChunk(pattern string) (star bool, chunk, rest string) {
for len(pattern) > 0 && pattern[0] == '*' {
pattern = pattern[1:]
star = true
}
inrange := false
var i int
Scan:
for i = 0; i < len(pattern); i++ {
switch pattern[i] {
case '\\':
// error check handled in matchChunk: bad pattern.
if i+1 < len(pattern) {
i++
}
case '[':
inrange = true
case ']':
inrange = false
case '*':
if !inrange {
break Scan
}
}
}
return star, pattern[0:i], pattern[i:]
}
// matchChunk checks whether chunk matches the beginning of s.
// If so, it returns the remainder of s (after the match).
// Chunk is all single-character operators: literals, char classes, and ?.
func matchChunk(chunk, s string) (rest string, ok bool, err error) {
for len(chunk) > 0 {
if len(s) == 0 {
return
}
switch chunk[0] {
case '[':
// character class
r, n := utf8.DecodeRuneInString(s)
s = s[n:]
chunk = chunk[1:]
// We can't end right after '[', we're expecting at least
// a closing bracket and possibly a caret.
if len(chunk) == 0 {
err = ErrBadPattern
return
}
// possibly negated
negated := chunk[0] == '^'
if negated {
chunk = chunk[1:]
}
// parse all ranges
match := false
nrange := 0
for {
if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
chunk = chunk[1:]
break
}
var lo, hi rune
if lo, chunk, err = getEsc(chunk); err != nil {
return
}
hi = lo
if chunk[0] == '-' {
if hi, chunk, err = getEsc(chunk[1:]); err != nil {
return
}
}
if lo <= r && r <= hi {
match = true
}
nrange++
}
if match == negated {
return
}
case '?':
if isPathSeparator(s[0]) {
return
}
_, n := utf8.DecodeRuneInString(s)
s = s[n:]
chunk = chunk[1:]
case '\\':
chunk = chunk[1:]
if len(chunk) == 0 {
err = ErrBadPattern
return
}
fallthrough
default:
if chunk[0] != s[0] {
return
}
s = s[1:]
chunk = chunk[1:]
}
}
return s, true, nil
}
// getEsc gets a possibly-escaped character from chunk, for a character class.
func getEsc(chunk string) (r rune, nchunk string, err error) {
if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
err = ErrBadPattern
return
}
if chunk[0] == '\\' {
chunk = chunk[1:]
if len(chunk) == 0 {
err = ErrBadPattern
return
}
}
r, n := utf8.DecodeRuneInString(chunk)
if r == utf8.RuneError && n == 1 {
err = ErrBadPattern
}
nchunk = chunk[n:]
if len(nchunk) == 0 {
err = ErrBadPattern
}
return
}
// Split splits path immediately following the final Separator,
// separating it into a directory and file name component.
// If there is no Separator in path, Split returns an empty dir
// and file set to path.
// The returned values have the property that path = dir+file.
func Split(path string) (dir, file string) {
i := len(path) - 1
for i >= 0 && !isPathSeparator(path[i]) {
i--
}
return path[:i+1], path[i+1:]
}
// Glob returns the names of all files matching pattern or nil
// if there is no matching file. The syntax of patterns is the same
// as in Match. The pattern may describe hierarchical names such as
// /usr/*/bin/ed (assuming the Separator is '/').
//
// Glob ignores file system errors such as I/O errors reading directories.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
func (c *Client) Glob(pattern string) (matches []string, err error) {
if !hasMeta(pattern) {
file, err := c.Lstat(pattern)
if err != nil {
return nil, nil
}
dir, _ := Split(pattern)
dir = cleanGlobPath(dir)
return []string{Join(dir, file.Name())}, nil
}
dir, file := Split(pattern)
dir = cleanGlobPath(dir)
if !hasMeta(dir) {
return c.glob(dir, file, nil)
}
// Prevent infinite recursion. See issue 15879.
if dir == pattern {
return nil, ErrBadPattern
}
var m []string
m, err = c.Glob(dir)
if err != nil {
return
}
for _, d := range m {
matches, err = c.glob(d, file, matches)
if err != nil {
return
}
}
return
}
// cleanGlobPath prepares path for glob matching.
func cleanGlobPath(path string) string {
switch path {
case "":
return "."
case string(separator):
// do nothing to the path
return path
default:
return path[0 : len(path)-1] // chop off trailing separator
}
}
// glob searches for files matching pattern in the directory dir
// and appends them to matches. If the directory cannot be
// opened, it returns the existing matches. New matches are
// added in lexicographical order.
func (c *Client) glob(dir, pattern string, matches []string) (m []string, e error) {
m = matches
fi, err := c.Stat(dir)
if err != nil {
return
}
if !fi.IsDir() {
return
}
names, err := c.ReadDir(dir)
if err != nil {
return
}
//sort.Strings(names)
for _, n := range names {
matched, err := Match(pattern, n.Name())
if err != nil {
return m, err
}
if matched {
m = append(m, Join(dir, n.Name()))
}
}
return
}
// Join joins any number of path elements into a single path, adding
// a Separator if necessary.
// all empty strings are ignored.
func Join(elem ...string) string {
return join(elem)
}
func join(elem []string) string {
// If there's a bug here, fix the logic in ./path_plan9.go too.
for i, e := range elem {
if e != "" {
return strings.Join(elem[i:], string(separator))
}
}
return ""
}
// hasMeta reports whether path contains any of the magic characters
// recognized by Match.
func hasMeta(path string) bool {
// TODO(niemeyer): Should other magic characters be added here?
return strings.ContainsAny(path, "*?[")
}

View File

@ -0,0 +1,156 @@
package sftp
import (
"encoding"
"sync"
)
// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming. This is due to some sftp clients requiring this
// behavior (eg. winscp).
type packetSender interface {
sendPacket(encoding.BinaryMarshaler) error
}
type packetManager struct {
requests chan requestPacket
responses chan responsePacket
fini chan struct{}
incoming requestPacketIDs
outgoing responsePackets
sender packetSender // connection object
working *sync.WaitGroup
}
func newPktMgr(sender packetSender) packetManager {
s := packetManager{
requests: make(chan requestPacket, sftpServerWorkerCount),
responses: make(chan responsePacket, sftpServerWorkerCount),
fini: make(chan struct{}),
incoming: make([]uint32, 0, sftpServerWorkerCount),
outgoing: make([]responsePacket, 0, sftpServerWorkerCount),
sender: sender,
working: &sync.WaitGroup{},
}
go s.controller()
return s
}
// register incoming packets to be handled
// send id of 0 for packets without id
func (s packetManager) incomingPacket(pkt requestPacket) {
s.working.Add(1)
s.requests <- pkt // buffer == sftpServerWorkerCount
}
// register outgoing packets as being ready
func (s packetManager) readyPacket(pkt responsePacket) {
s.responses <- pkt
s.working.Done()
}
// shut down packetManager controller
func (s packetManager) close() {
// pause until current packets are processed
s.working.Wait()
close(s.fini)
}
// Passed a worker function, returns a channel for incoming packets.
// The goal is to process packets in the order they are received as is
// requires by section 7 of the RFC, while maximizing throughput of file
// transfers.
func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan {
rwChan := make(chan requestPacket, sftpServerWorkerCount)
for i := 0; i < sftpServerWorkerCount; i++ {
runWorker(rwChan)
}
cmdChan := make(chan requestPacket)
runWorker(cmdChan)
pktChan := make(chan requestPacket, sftpServerWorkerCount)
go func() {
// start with cmdChan
curChan := cmdChan
for pkt := range pktChan {
// on file open packet, switch to rwChan
switch pkt.(type) {
case *sshFxpOpenPacket:
curChan = rwChan
// on file close packet, switch back to cmdChan
// after waiting for any reads/writes to finish
case *sshFxpClosePacket:
// wait for rwChan to finish
s.working.Wait()
// stop using rwChan
curChan = cmdChan
}
s.incomingPacket(pkt)
curChan <- pkt
}
close(rwChan)
close(cmdChan)
s.close()
}()
return pktChan
}
// process packets
func (s *packetManager) controller() {
for {
select {
case pkt := <-s.requests:
debug("incoming id: %v", pkt.id())
s.incoming = append(s.incoming, pkt.id())
if len(s.incoming) > 1 {
s.incoming.Sort()
}
case pkt := <-s.responses:
debug("outgoing pkt: %v", pkt.id())
s.outgoing = append(s.outgoing, pkt)
if len(s.outgoing) > 1 {
s.outgoing.Sort()
}
case <-s.fini:
return
}
s.maybeSendPackets()
}
}
// send as many packets as are ready
func (s *packetManager) maybeSendPackets() {
for {
if len(s.outgoing) == 0 || len(s.incoming) == 0 {
debug("break! -- outgoing: %v; incoming: %v",
len(s.outgoing), len(s.incoming))
break
}
out := s.outgoing[0]
in := s.incoming[0]
// debug("incoming: %v", s.incoming)
// debug("outgoing: %v", outfilter(s.outgoing))
if in == out.id() {
s.sender.sendPacket(out)
// pop off heads
copy(s.incoming, s.incoming[1:]) // shift left
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
copy(s.outgoing, s.outgoing[1:]) // shift left
s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
} else {
break
}
}
}
func outfilter(o []responsePacket) []uint32 {
res := make([]uint32, 0, len(o))
for _, v := range o {
res = append(res, v.id())
}
return res
}

View File

@ -0,0 +1,21 @@
// +build go1.8
package sftp
import "sort"
type responsePackets []responsePacket
func (r responsePackets) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i].id() < r[j].id()
})
}
type requestPacketIDs []uint32
func (r requestPacketIDs) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i] < r[j]
})
}

View File

@ -0,0 +1,21 @@
// +build !go1.8
package sftp
import "sort"
// for sorting/ordering outgoing
type responsePackets []responsePacket
func (r responsePackets) Len() int { return len(r) }
func (r responsePackets) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
func (r responsePackets) Less(i, j int) bool { return r[i].id() < r[j].id() }
func (r responsePackets) Sort() { sort.Sort(r) }
// for sorting/ordering incoming
type requestPacketIDs []uint32
func (r requestPacketIDs) Len() int { return len(r) }
func (r requestPacketIDs) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
func (r requestPacketIDs) Less(i, j int) bool { return r[i] < r[j] }
func (r requestPacketIDs) Sort() { sort.Sort(r) }

View File

@ -0,0 +1,140 @@
package sftp
import (
"encoding"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type _testSender struct {
sent chan encoding.BinaryMarshaler
}
func newTestSender() *_testSender {
return &_testSender{make(chan encoding.BinaryMarshaler)}
}
func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
s.sent <- p
return nil
}
type fakepacket uint32
func (fakepacket) MarshalBinary() ([]byte, error) {
return []byte{}, nil
}
func (fakepacket) UnmarshalBinary([]byte) error {
return nil
}
func (f fakepacket) id() uint32 {
return uint32(f)
}
type pair struct {
in fakepacket
out fakepacket
}
// basic test
var ttable1 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(1)},
pair{fakepacket(2), fakepacket(2)},
pair{fakepacket(3), fakepacket(3)},
}
// outgoing packets out of order
var ttable2 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(4)},
pair{fakepacket(2), fakepacket(1)},
pair{fakepacket(3), fakepacket(3)},
pair{fakepacket(4), fakepacket(2)},
}
// incoming packets out of order
var ttable3 = []pair{
pair{fakepacket(2), fakepacket(0)},
pair{fakepacket(1), fakepacket(1)},
pair{fakepacket(3), fakepacket(2)},
pair{fakepacket(0), fakepacket(3)},
}
var tables = [][]pair{ttable1, ttable2, ttable3}
func TestPacketManager(t *testing.T) {
sender := newTestSender()
s := newPktMgr(sender)
for i := range tables {
table := tables[i]
for _, p := range table {
s.incomingPacket(p.in)
}
for _, p := range table {
s.readyPacket(p.out)
}
for i := 0; i < len(table); i++ {
pkt := <-sender.sent
id := pkt.(fakepacket).id()
assert.Equal(t, id, uint32(i))
}
}
s.close()
}
// Test what happens when the pool processes a close packet on a file that it
// is still reading from.
func TestCloseOutOfOrder(t *testing.T) {
packets := []requestPacket{
&sshFxpRemovePacket{ID: 0, Filename: "foo"},
&sshFxpOpenPacket{ID: 1},
&sshFxpWritePacket{ID: 2, Handle: "foo"},
&sshFxpWritePacket{ID: 3, Handle: "foo"},
&sshFxpWritePacket{ID: 4, Handle: "foo"},
&sshFxpWritePacket{ID: 5, Handle: "foo"},
&sshFxpClosePacket{ID: 6, Handle: "foo"},
&sshFxpRemovePacket{ID: 7, Filename: "foo"},
}
recvChan := make(chan requestPacket, len(packets)+1)
sender := newTestSender()
pktMgr := newPktMgr(sender)
wg := sync.WaitGroup{}
wg.Add(len(packets))
runWorker := func(ch requestChan) {
go func() {
for pkt := range ch {
if _, ok := pkt.(*sshFxpWritePacket); ok {
// sleep to cause writes to come after close/remove
time.Sleep(time.Millisecond)
}
pktMgr.working.Done()
recvChan <- pkt
wg.Done()
}
}()
}
pktChan := pktMgr.workerChan(runWorker)
for _, p := range packets {
pktChan <- p
}
wg.Wait()
close(recvChan)
received := []requestPacket{}
for p := range recvChan {
received = append(received, p)
}
if received[len(received)-2].id() != packets[len(packets)-2].id() {
t.Fatal("Packets processed out of order1:", received, packets)
}
if received[len(received)-1].id() != packets[len(packets)-1].id() {
t.Fatal("Packets processed out of order2:", received, packets)
}
}

View File

@ -0,0 +1,141 @@
package sftp
import (
"encoding"
"github.com/pkg/errors"
)
// all incoming packets
type requestPacket interface {
encoding.BinaryUnmarshaler
id() uint32
}
type requestChan chan requestPacket
type responsePacket interface {
encoding.BinaryMarshaler
id() uint32
}
// interfaces to group types
type hasPath interface {
requestPacket
getPath() string
}
type hasHandle interface {
requestPacket
getHandle() string
}
type isOpener interface {
hasPath
isOpener()
}
type notReadOnly interface {
notReadOnly()
}
//// define types by adding methods
// hasPath
func (p sshFxpLstatPacket) getPath() string { return p.Path }
func (p sshFxpStatPacket) getPath() string { return p.Path }
func (p sshFxpRmdirPacket) getPath() string { return p.Path }
func (p sshFxpReadlinkPacket) getPath() string { return p.Path }
func (p sshFxpRealpathPacket) getPath() string { return p.Path }
func (p sshFxpMkdirPacket) getPath() string { return p.Path }
func (p sshFxpSetstatPacket) getPath() string { return p.Path }
func (p sshFxpStatvfsPacket) getPath() string { return p.Path }
func (p sshFxpRemovePacket) getPath() string { return p.Filename }
func (p sshFxpRenamePacket) getPath() string { return p.Oldpath }
func (p sshFxpSymlinkPacket) getPath() string { return p.Targetpath }
// Openers implement hasPath and isOpener
func (p sshFxpOpendirPacket) getPath() string { return p.Path }
func (p sshFxpOpendirPacket) isOpener() {}
func (p sshFxpOpenPacket) getPath() string { return p.Path }
func (p sshFxpOpenPacket) isOpener() {}
// hasHandle
func (p sshFxpFstatPacket) getHandle() string { return p.Handle }
func (p sshFxpFsetstatPacket) getHandle() string { return p.Handle }
func (p sshFxpReadPacket) getHandle() string { return p.Handle }
func (p sshFxpWritePacket) getHandle() string { return p.Handle }
func (p sshFxpReaddirPacket) getHandle() string { return p.Handle }
// notReadOnly
func (p sshFxpWritePacket) notReadOnly() {}
func (p sshFxpSetstatPacket) notReadOnly() {}
func (p sshFxpFsetstatPacket) notReadOnly() {}
func (p sshFxpRemovePacket) notReadOnly() {}
func (p sshFxpMkdirPacket) notReadOnly() {}
func (p sshFxpRmdirPacket) notReadOnly() {}
func (p sshFxpRenamePacket) notReadOnly() {}
func (p sshFxpSymlinkPacket) notReadOnly() {}
// this has a handle, but is only used for close
func (p sshFxpClosePacket) getHandle() string { return p.Handle }
// some packets with ID are missing id()
func (p sshFxpDataPacket) id() uint32 { return p.ID }
func (p sshFxpStatusPacket) id() uint32 { return p.ID }
func (p sshFxpStatResponse) id() uint32 { return p.ID }
func (p sshFxpNamePacket) id() uint32 { return p.ID }
func (p sshFxpHandlePacket) id() uint32 { return p.ID }
func (p sshFxVersionPacket) id() uint32 { return 0 }
// take raw incoming packet data and build packet objects
func makePacket(p rxPacket) (requestPacket, error) {
var pkt requestPacket
switch p.pktType {
case ssh_FXP_INIT:
pkt = &sshFxInitPacket{}
case ssh_FXP_LSTAT:
pkt = &sshFxpLstatPacket{}
case ssh_FXP_OPEN:
pkt = &sshFxpOpenPacket{}
case ssh_FXP_CLOSE:
pkt = &sshFxpClosePacket{}
case ssh_FXP_READ:
pkt = &sshFxpReadPacket{}
case ssh_FXP_WRITE:
pkt = &sshFxpWritePacket{}
case ssh_FXP_FSTAT:
pkt = &sshFxpFstatPacket{}
case ssh_FXP_SETSTAT:
pkt = &sshFxpSetstatPacket{}
case ssh_FXP_FSETSTAT:
pkt = &sshFxpFsetstatPacket{}
case ssh_FXP_OPENDIR:
pkt = &sshFxpOpendirPacket{}
case ssh_FXP_READDIR:
pkt = &sshFxpReaddirPacket{}
case ssh_FXP_REMOVE:
pkt = &sshFxpRemovePacket{}
case ssh_FXP_MKDIR:
pkt = &sshFxpMkdirPacket{}
case ssh_FXP_RMDIR:
pkt = &sshFxpRmdirPacket{}
case ssh_FXP_REALPATH:
pkt = &sshFxpRealpathPacket{}
case ssh_FXP_STAT:
pkt = &sshFxpStatPacket{}
case ssh_FXP_RENAME:
pkt = &sshFxpRenamePacket{}
case ssh_FXP_READLINK:
pkt = &sshFxpReadlinkPacket{}
case ssh_FXP_SYMLINK:
pkt = &sshFxpSymlinkPacket{}
case ssh_FXP_EXTENDED:
pkt = &sshFxpExtendedPacket{}
default:
return nil, errors.Errorf("unhandled packet type: %s", p.pktType)
}
if err := pkt.UnmarshalBinary(p.pktBytes); err != nil {
return nil, err
}
return pkt, nil
}

View File

@ -170,9 +170,6 @@ func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) {
return ep, b, err return ep, b, err
} }
ep.Data, b, err = unmarshalStringSafe(b) ep.Data, b, err = unmarshalStringSafe(b)
if err != nil {
return ep, b, err
}
return ep, b, err return ep, b, err
} }

View File

@ -0,0 +1,244 @@
package sftp
// This serves as an example of how to implement the request server handler as
// well as a dummy backend for testing. It implements an in-memory backend that
// works as a very simple filesystem with simple flat key-value lookup system.
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strconv"
"sync"
"time"
)
// InMemHandler returns a Hanlders object with the test handlers
func InMemHandler() Handlers {
root := &root{
files: make(map[string]*memFile),
}
root.memFile = newMemFile("/", true)
return Handlers{root, root, root, root}
}
// Handlers
func (fs *root) Fileread(r Request) (io.ReaderAt, error) {
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
file, err := fs.fetch(r.Filepath)
if err != nil {
return nil, err
}
if file.symlink != "" {
file, err = fs.fetch(file.symlink)
if err != nil {
return nil, err
}
}
return file.ReaderAt()
}
func (fs *root) Filewrite(r Request) (io.WriterAt, error) {
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
file, err := fs.fetch(r.Filepath)
if err == os.ErrNotExist {
dir, err := fs.fetch(filepath.Dir(r.Filepath))
if err != nil {
return nil, err
}
if !dir.isdir {
return nil, os.ErrInvalid
}
file = newMemFile(r.Filepath, false)
fs.files[r.Filepath] = file
}
return file.WriterAt()
}
func (fs *root) Filecmd(r Request) error {
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
switch r.Method {
case "Setstat":
return nil
case "Rename":
file, err := fs.fetch(r.Filepath)
if err != nil {
return err
}
if _, ok := fs.files[r.Target]; ok {
return &os.LinkError{Op: "rename", Old: r.Filepath, New: r.Target,
Err: fmt.Errorf("dest file exists")}
}
fs.files[r.Target] = file
delete(fs.files, r.Filepath)
case "Rmdir", "Remove":
_, err := fs.fetch(filepath.Dir(r.Filepath))
if err != nil {
return err
}
delete(fs.files, r.Filepath)
case "Mkdir":
_, err := fs.fetch(filepath.Dir(r.Filepath))
if err != nil {
return err
}
fs.files[r.Filepath] = newMemFile(r.Filepath, true)
case "Symlink":
_, err := fs.fetch(r.Filepath)
if err != nil {
return err
}
link := newMemFile(r.Target, false)
link.symlink = r.Filepath
fs.files[r.Target] = link
}
return nil
}
func (fs *root) Fileinfo(r Request) ([]os.FileInfo, error) {
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
switch r.Method {
case "List":
var err error
batch_size := 10
current_offset := 0
if token := r.LsNext(); token != "" {
current_offset, err = strconv.Atoi(token)
if err != nil {
return nil, os.ErrInvalid
}
}
ordered_names := []string{}
for fn, _ := range fs.files {
if filepath.Dir(fn) == r.Filepath {
ordered_names = append(ordered_names, fn)
}
}
sort.Sort(sort.StringSlice(ordered_names))
list := make([]os.FileInfo, len(ordered_names))
for i, fn := range ordered_names {
list[i] = fs.files[fn]
}
if len(list) < current_offset {
return nil, io.EOF
}
new_offset := current_offset + batch_size
if new_offset > len(list) {
new_offset = len(list)
}
r.LsSave(strconv.Itoa(new_offset))
return list[current_offset:new_offset], nil
case "Stat":
file, err := fs.fetch(r.Filepath)
if err != nil {
return nil, err
}
return []os.FileInfo{file}, nil
case "Readlink":
file, err := fs.fetch(r.Filepath)
if err != nil {
return nil, err
}
if file.symlink != "" {
file, err = fs.fetch(file.symlink)
if err != nil {
return nil, err
}
}
return []os.FileInfo{file}, nil
}
return nil, nil
}
// In memory file-system-y thing that the Hanlders live on
type root struct {
*memFile
files map[string]*memFile
filesLock sync.Mutex
}
func (fs *root) fetch(path string) (*memFile, error) {
if path == "/" {
return fs.memFile, nil
}
if file, ok := fs.files[path]; ok {
return file, nil
}
return nil, os.ErrNotExist
}
// Implements os.FileInfo, Reader and Writer interfaces.
// These are the 3 interfaces necessary for the Handlers.
type memFile struct {
name string
modtime time.Time
symlink string
isdir bool
content []byte
contentLock sync.RWMutex
}
// factory to make sure modtime is set
func newMemFile(name string, isdir bool) *memFile {
return &memFile{
name: name,
modtime: time.Now(),
isdir: isdir,
}
}
// Have memFile fulfill os.FileInfo interface
func (f *memFile) Name() string { return filepath.Base(f.name) }
func (f *memFile) Size() int64 { return int64(len(f.content)) }
func (f *memFile) Mode() os.FileMode {
ret := os.FileMode(0644)
if f.isdir {
ret = os.FileMode(0755) | os.ModeDir
}
if f.symlink != "" {
ret = os.FileMode(0777) | os.ModeSymlink
}
return ret
}
func (f *memFile) ModTime() time.Time { return f.modtime }
func (f *memFile) IsDir() bool { return f.isdir }
func (f *memFile) Sys() interface{} {
return fakeFileInfoSys()
}
// Read/Write
func (f *memFile) ReaderAt() (io.ReaderAt, error) {
if f.isdir {
return nil, os.ErrInvalid
}
return bytes.NewReader(f.content), nil
}
func (f *memFile) WriterAt() (io.WriterAt, error) {
if f.isdir {
return nil, os.ErrInvalid
}
return f, nil
}
func (f *memFile) WriteAt(p []byte, off int64) (int, error) {
// fmt.Println(string(p), off)
// mimic write delays, should be optional
time.Sleep(time.Microsecond * time.Duration(len(p)))
f.contentLock.Lock()
defer f.contentLock.Unlock()
plen := len(p) + int(off)
if plen >= len(f.content) {
nc := make([]byte, plen)
copy(nc, f.content)
f.content = nc
}
copy(f.content[off:], p)
return len(p), nil
}

View File

@ -0,0 +1,30 @@
package sftp
import (
"io"
"os"
)
// Interfaces are differentiated based on required returned values.
// All input arguments are to be pulled from Request (the only arg).
// FileReader should return an io.Reader for the filepath
type FileReader interface {
Fileread(Request) (io.ReaderAt, error)
}
// FileWriter should return an io.Writer for the filepath
type FileWriter interface {
Filewrite(Request) (io.WriterAt, error)
}
// FileCmder should return an error (rename, remove, setstate, etc.)
type FileCmder interface {
Filecmd(Request) error
}
// FileInfoer should return file listing info and errors (readdir, stat)
// note stat requests would return a list of 1
type FileInfoer interface {
Fileinfo(Request) ([]os.FileInfo, error)
}

View File

@ -0,0 +1,48 @@
# Request Based SFTP API
The request based API allows for custom backends in a way similar to the http
package. In order to create a backend you need to implement 4 handler
interfaces; one for reading, one for writing, one for misc commands and one for
listing files. Each has 1 required method and in each case those methods take
the Request as the only parameter and they each return something different.
These 4 interfaces are enough to handle all the SFTP traffic in a simplified
manner.
The Request structure has 5 public fields which you will deal with.
- Method (string) - string name of incoming call
- Filepath (string) - path of file to act on
- Attrs ([]byte) - byte string of file attribute data
- Target (string) - target path for renames and sym-links
Below are the methods and a brief description of what they need to do.
### Fileread(*Request) (io.Reader, error)
Handler for "Get" method and returns an io.Reader for the file which the server
then sends to the client.
### Filewrite(*Request) (io.Writer, error)
Handler for "Put" method and returns an io.Writer for the file which the server
then writes the uploaded file to.
### Filecmd(*Request) error
Handles "SetStat", "Rename", "Rmdir", "Mkdir" and "Symlink" methods. Makes the
appropriate changes and returns nil for success or an filesystem like error
(eg. os.ErrNotExist).
### Fileinfo(*Request) ([]os.FileInfo, error)
Handles "List", "Stat", "Readlink" methods. Gathers/creates FileInfo structs
with the data on the files and returns in a list (list of 1 for Stat and
Readlink).
## TODO
- Add support for API users to see trace/debugging info of what is going on
inside SFTP server.
- Consider adding support for SFTP file append only mode.

View File

@ -0,0 +1,231 @@
package sftp
import (
"encoding"
"io"
"os"
"path/filepath"
"strconv"
"sync"
"syscall"
"github.com/pkg/errors"
)
var maxTxPacket uint32 = 1 << 15
type handleHandler func(string) string
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileInfo FileInfoer
}
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
*serverConn
Handlers Handlers
pktMgr packetManager
openRequests map[string]Request
openRequestLock sync.RWMutex
handleCount int
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
svrConn := &serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
}
return &RequestServer{
serverConn: svrConn,
Handlers: h,
pktMgr: newPktMgr(svrConn),
openRequests: make(map[string]Request),
}
}
func (rs *RequestServer) nextRequest(r Request) string {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
rs.openRequests[handle] = r
return handle
}
func (rs *RequestServer) getRequest(handle string) (Request, bool) {
rs.openRequestLock.RLock()
defer rs.openRequestLock.RUnlock()
r, ok := rs.openRequests[handle]
return r, ok
}
func (rs *RequestServer) closeRequest(handle string) {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
if r, ok := rs.openRequests[handle]; ok {
r.close()
delete(rs.openRequests, handle)
}
}
// Close the read/write/closer to trigger exiting the main server loop
func (rs *RequestServer) Close() error { return rs.conn.Close() }
// Serve requests for user session
func (rs *RequestServer) Serve() error {
var wg sync.WaitGroup
runWorker := func(ch requestChan) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)
var err error
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = rs.recvPacket()
if err != nil {
break
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil {
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
break
}
pktChan <- pkt
}
close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
return err
}
func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
for pkt := range pktChan {
var rpkt responsePacket
switch pkt := pkt.(type) {
case *sshFxInitPacket:
rpkt = sshFxVersionPacket{sftpProtocolVersion, nil}
case *sshFxpClosePacket:
handle := pkt.getHandle()
rs.closeRequest(handle)
rpkt = statusFromError(pkt, nil)
case *sshFxpRealpathPacket:
rpkt = cleanPath(pkt)
case isOpener:
handle := rs.nextRequest(requestFromPacket(pkt))
rpkt = sshFxpHandlePacket{pkt.id(), handle}
case *sshFxpFstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = requestFromPacket(
&sshFxpStatPacket{ID: pkt.id(), Path: request.Filepath})
rpkt = rs.handle(request, pkt)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = requestFromPacket(
&sshFxpSetstatPacket{ID: pkt.id(), Path: request.Filepath,
Flags: pkt.Flags, Attrs: pkt.Attrs,
})
rpkt = rs.handle(request, pkt)
}
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
request.update(pkt)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
rpkt = rs.handle(request, pkt)
}
case hasPath:
request := requestFromPacket(pkt)
rpkt = rs.handle(request, pkt)
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
err := rs.sendPacket(rpkt)
if err != nil {
return err
}
}
return nil
}
func cleanPath(pkt *sshFxpRealpathPacket) responsePacket {
path := pkt.getPath()
if !filepath.IsAbs(path) {
path = "/" + path
} // all paths are absolute
cleaned_path := filepath.Clean(path)
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []sshFxpNameAttr{{
Name: cleaned_path,
LongName: cleaned_path,
Attrs: emptyFileStat,
}},
}
}
func (rs *RequestServer) handle(request Request, pkt requestPacket) responsePacket {
// fmt.Println("Request Method: ", request.Method)
rpkt, err := request.handle(rs.Handlers)
if err != nil {
err = errorAdapter(err)
rpkt = statusFromError(pkt, err)
}
return rpkt
}
// Wrap underlying connection methods to use packetManager
func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
if pkt, ok := m.(responsePacket); ok {
rs.pktMgr.readyPacket(pkt)
} else {
return errors.Errorf("unexpected packet type %T", m)
}
return nil
}
func (rs *RequestServer) sendError(p ider, err error) error {
return rs.sendPacket(statusFromError(p, err))
}
// os.ErrNotExist should convert to ssh_FX_NO_SUCH_FILE, but is not recognized
// by statusFromError. So we convert to syscall.ENOENT which it does.
func errorAdapter(err error) error {
if err == os.ErrNotExist {
return syscall.ENOENT
}
return err
}

View File

@ -0,0 +1,329 @@
package sftp
import (
"fmt"
"io"
"net"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
var _ = fmt.Print
type csPair struct {
cli *Client
svr *RequestServer
}
// these must be closed in order, else client.Close will hang
func (cs csPair) Close() {
cs.svr.Close()
cs.cli.Close()
os.Remove(sock)
}
func (cs csPair) testHandler() *root {
return cs.svr.Handlers.FileGet.(*root)
}
const sock = "/tmp/rstest.sock"
func clientRequestServerPair(t *testing.T) *csPair {
ready := make(chan bool)
os.Remove(sock) // either this or signal handling
var server *RequestServer
go func() {
l, err := net.Listen("unix", sock)
if err != nil {
// neither assert nor t.Fatal reliably exit before Accept errors
panic(err)
}
ready <- true
fd, err := l.Accept()
assert.Nil(t, err)
handlers := InMemHandler()
server = NewRequestServer(fd, handlers)
server.Serve()
}()
<-ready
defer os.Remove(sock)
c, err := net.Dial("unix", sock)
assert.Nil(t, err)
client, err := NewClientPipe(c, c)
if err != nil {
t.Fatalf("%+v\n", err)
}
return &csPair{client, server}
}
// after adding logging, maybe check log to make sure packet handling
// was split over more than one worker
func TestRequestSplitWrite(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
w, err := p.cli.Create("/foo")
assert.Nil(t, err)
p.cli.maxPacket = 3 // force it to send in small chunks
contents := "one two three four five six seven eight nine ten"
w.Write([]byte(contents))
w.Close()
r := p.testHandler()
f, _ := r.fetch("/foo")
assert.Equal(t, contents, string(f.content))
}
func TestRequestCache(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
foo := NewRequest("", "foo")
bar := NewRequest("", "bar")
fh := p.svr.nextRequest(foo)
bh := p.svr.nextRequest(bar)
assert.Len(t, p.svr.openRequests, 2)
_foo, ok := p.svr.getRequest(fh)
assert.Equal(t, foo, _foo)
assert.True(t, ok)
_, ok = p.svr.getRequest("zed")
assert.False(t, ok)
p.svr.closeRequest(fh)
p.svr.closeRequest(bh)
assert.Len(t, p.svr.openRequests, 0)
}
func TestRequestCacheState(t *testing.T) {
// test operation that uses open/close
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
assert.Len(t, p.svr.openRequests, 0)
// test operation that doesn't open/close
err = p.cli.Remove("/foo")
assert.Nil(t, err)
assert.Len(t, p.svr.openRequests, 0)
}
func putTestFile(cli *Client, path, content string) (int, error) {
w, err := cli.Create(path)
if err == nil {
defer w.Close()
return w.Write([]byte(content))
}
return 0, err
}
func TestRequestWrite(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
n, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
assert.Equal(t, 5, n)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.Nil(t, err)
assert.False(t, f.isdir)
assert.Equal(t, f.content, []byte("hello"))
}
// needs fail check
func TestRequestFilename(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.Nil(t, err)
assert.Equal(t, f.Name(), "foo")
}
func TestRequestRead(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
rf, err := p.cli.Open("/foo")
assert.Nil(t, err)
defer rf.Close()
contents := make([]byte, 5)
n, err := rf.Read(contents)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
assert.Equal(t, 5, n)
assert.Equal(t, "hello", string(contents[0:5]))
}
func TestRequestReadFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
rf, err := p.cli.Open("/foo")
assert.Nil(t, err)
contents := make([]byte, 5)
n, err := rf.Read(contents)
assert.Equal(t, n, 0)
assert.Exactly(t, os.ErrNotExist, err)
}
func TestRequestOpen(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
fh, err := p.cli.Open("foo")
assert.Nil(t, err)
err = fh.Close()
assert.Nil(t, err)
}
func TestRequestMkdir(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
err := p.cli.Mkdir("/foo")
assert.Nil(t, err)
r := p.testHandler()
f, err := r.fetch("/foo")
assert.Nil(t, err)
assert.True(t, f.isdir)
}
func TestRequestRemove(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
r := p.testHandler()
_, err = r.fetch("/foo")
assert.Nil(t, err)
err = p.cli.Remove("/foo")
assert.Nil(t, err)
_, err = r.fetch("/foo")
assert.Equal(t, err, os.ErrNotExist)
}
func TestRequestRename(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
r := p.testHandler()
_, err = r.fetch("/foo")
assert.Nil(t, err)
err = p.cli.Rename("/foo", "/bar")
assert.Nil(t, err)
_, err = r.fetch("/bar")
assert.Nil(t, err)
_, err = r.fetch("/foo")
assert.Equal(t, err, os.ErrNotExist)
}
func TestRequestRenameFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
_, err = putTestFile(p.cli, "/bar", "goodbye")
assert.Nil(t, err)
err = p.cli.Rename("/foo", "/bar")
assert.IsType(t, &StatusError{}, err)
}
func TestRequestStat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
fi, err := p.cli.Stat("/foo")
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
}
// NOTE: Setstat is a noop in the request server tests, but we want to test
// that is does nothing without crapping out.
func TestRequestSetstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
mode := os.FileMode(0644)
err = p.cli.Chmod("/foo", mode)
assert.Nil(t, err)
fi, err := p.cli.Stat("/foo")
assert.Nil(t, err)
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
}
func TestRequestFstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
fp, err := p.cli.Open("/foo")
assert.Nil(t, err)
fi, err := fp.Stat()
assert.Nil(t, err)
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
}
func TestRequestStatFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
fi, err := p.cli.Stat("/foo")
assert.Nil(t, fi)
assert.True(t, os.IsNotExist(err))
}
func TestRequestSymlink(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
err = p.cli.Symlink("/foo", "/bar")
assert.Nil(t, err)
r := p.testHandler()
fi, err := r.fetch("/bar")
assert.Nil(t, err)
assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink)
}
func TestRequestSymlinkFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
err := p.cli.Symlink("/foo", "/bar")
assert.True(t, os.IsNotExist(err))
}
func TestRequestReadlink(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
err = p.cli.Symlink("/foo", "/bar")
assert.Nil(t, err)
rl, err := p.cli.ReadLink("/bar")
assert.Nil(t, err)
assert.Equal(t, "foo", rl)
}
func TestRequestReaddir(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
for i := 0; i < 100; i++ {
fname := fmt.Sprintf("/foo_%02d", i)
_, err := putTestFile(p.cli, fname, fname)
assert.Nil(t, err)
}
di, err := p.cli.ReadDir("/")
assert.Nil(t, err)
assert.Len(t, di, 100)
names := []string{di[18].Name(), di[81].Name()}
assert.Equal(t, []string{"foo_18", "foo_81"}, names)
}

View File

@ -0,0 +1,23 @@
// +build !windows
package sftp
import (
"errors"
"syscall"
)
func fakeFileInfoSys() interface{} {
return &syscall.Stat_t{Uid: 65534, Gid: 65534}
}
func testOsSys(sys interface{}) error {
fstat := sys.(*FileStat)
if fstat.UID != uint32(65534) {
return errors.New("Uid failed to match.")
}
if fstat.GID != uint32(65534) {
return errors.New("Gid failed to match:")
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More