2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-22 22:58:26 +00:00

Merge pull request #922 from restic/update-deps

Update vendored dependencies
This commit is contained in:
Alexander Neumann 2017-04-19 18:30:41 +02:00
commit 5b80cb8b6b
264 changed files with 26514 additions and 6157 deletions

View File

@ -49,7 +49,7 @@ func specialDir(name string) bool {
// excludePath returns true if the file should not be copied to the new GOPATH.
func excludePath(name string) bool {
ext := path.Ext(name)
if ext == ".go" || ext == ".s" {
if ext == ".go" || ext == ".s" || ext == ".h" {
return false
}

View File

@ -12,8 +12,8 @@ import (
// Getxattr retrieves extended attribute data associated with path.
func Getxattr(path, name string) ([]byte, error) {
b, e := xattr.Getxattr(path, name)
if err, ok := e.(*xattr.XAttrError); ok && err.Err == syscall.ENOTSUP {
b, e := xattr.Get(path, name)
if err, ok := e.(*xattr.Error); ok && err.Err == syscall.ENOTSUP {
return nil, nil
}
return b, errors.Wrap(e, "Getxattr")
@ -22,8 +22,8 @@ func Getxattr(path, name string) ([]byte, error) {
// Listxattr retrieves a list of names of extended attributes associated with the
// given path in the file system.
func Listxattr(path string) ([]string, error) {
s, e := xattr.Listxattr(path)
if err, ok := e.(*xattr.XAttrError); ok && err.Err == syscall.ENOTSUP {
s, e := xattr.List(path)
if err, ok := e.(*xattr.Error); ok && err.Err == syscall.ENOTSUP {
return nil, nil
}
return s, errors.Wrap(e, "Listxattr")
@ -31,8 +31,8 @@ func Listxattr(path string) ([]string, error) {
// Setxattr associates name and data together as an attribute of path.
func Setxattr(path, name string, data []byte) error {
e := xattr.Setxattr(path, name, data)
if err, ok := e.(*xattr.XAttrError); ok && err.Err == syscall.ENOTSUP {
e := xattr.Set(path, name, data)
if err, ok := e.(*xattr.Error); ok && err.Err == syscall.ENOTSUP {
return nil
}
return errors.Wrap(e, "Setxattr")

36
vendor/manifest vendored
View File

@ -28,14 +28,14 @@
{
"importpath": "github.com/minio/minio-go",
"repository": "https://github.com/minio/minio-go",
"revision": "b1674741d196d5d79486d7c1645ed6ded902b712",
"branch": "master"
"revision": "dcaae9ec4d0b0a81d17f22f6d7a186491f6a55ec",
"branch": "HEAD"
},
{
"importpath": "github.com/pkg/errors",
"repository": "https://github.com/pkg/errors",
"revision": "17b591df37844cde689f4d5813e5cea0927d8dd2",
"branch": "master"
"revision": "645ef00459ed84a119197bfb8d8205042c6df63d",
"branch": "HEAD"
},
{
"importpath": "github.com/pkg/profile",
@ -52,80 +52,80 @@
{
"importpath": "github.com/pkg/xattr",
"repository": "https://github.com/pkg/xattr",
"revision": "b867675798fa7708a444945602b452ca493f2272",
"branch": "master"
"revision": "858d49c224b241ba9393e20f521f6a76f52dd482",
"branch": "HEAD"
},
{
"importpath": "github.com/restic/chunker",
"repository": "https://github.com/restic/chunker",
"revision": "49e9b5212b022a1ab373faf981ed4f2fc807502a",
"branch": "master"
"revision": "bb2ecf9a98e35a0b336ffc23fc515fb6e7961577",
"branch": "HEAD"
},
{
"importpath": "github.com/spf13/cobra",
"repository": "https://github.com/spf13/cobra",
"revision": "9c28e4bbd74e5c3ed7aacbc552b2cab7cfdfe744",
"revision": "b6cb3958937245a12d4d7728be080a6c758f4136",
"branch": "master"
},
{
"importpath": "github.com/spf13/pflag",
"repository": "https://github.com/spf13/pflag",
"revision": "c7e63cf4530bcd3ba943729cee0efeff2ebea63f",
"revision": "2300d0f8576fe575f71aaa5b9bbe4e1b0dc2eb51",
"branch": "master"
},
{
"importpath": "golang.org/x/crypto/curve25519",
"repository": "https://go.googlesource.com/crypto",
"revision": "81372b2fc2f10bef2a7f338da115c315a56b2726",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/curve25519"
},
{
"importpath": "golang.org/x/crypto/ed25519",
"repository": "https://go.googlesource.com/crypto",
"revision": "81372b2fc2f10bef2a7f338da115c315a56b2726",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/ed25519"
},
{
"importpath": "golang.org/x/crypto/pbkdf2",
"repository": "https://go.googlesource.com/crypto",
"revision": "81372b2fc2f10bef2a7f338da115c315a56b2726",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/pbkdf2"
},
{
"importpath": "golang.org/x/crypto/poly1305",
"repository": "https://go.googlesource.com/crypto",
"revision": "5f31782cfb2b6373211f8f9fbf31283fa234b570",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/poly1305"
},
{
"importpath": "golang.org/x/crypto/scrypt",
"repository": "https://go.googlesource.com/crypto",
"revision": "81372b2fc2f10bef2a7f338da115c315a56b2726",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/scrypt"
},
{
"importpath": "golang.org/x/crypto/ssh",
"repository": "https://go.googlesource.com/crypto",
"revision": "81372b2fc2f10bef2a7f338da115c315a56b2726",
"revision": "efac7f277b17c19894091e358c6130cb6bd51117",
"branch": "master",
"path": "/ssh"
},
{
"importpath": "golang.org/x/net/context",
"repository": "https://go.googlesource.com/net",
"revision": "de35ec43e7a9aabd6a9c54d2898220ea7e44de7d",
"revision": "5602c733f70afc6dcec6766be0d5034d4c4f14de",
"branch": "master",
"path": "/context"
},
{
"importpath": "golang.org/x/sys/unix",
"repository": "https://go.googlesource.com/sys",
"revision": "30de6d19a3bd89a5f38ae4028e23aaa5582648af",
"revision": "f3918c30c5c2cb527c0b071a27c35120a6c0719a",
"branch": "master",
"path": "/unix"
}

View File

@ -1,5 +1,6 @@
# Minio Golang Library for Amazon S3 Compatible Cloud Storage [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/Minio/minio?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
The Minio Golang Client SDK provides simple APIs to access any Amazon S3 compatible object storage server.
# Minio Go Client SDK for Amazon S3 Compatible Cloud Storage [![Slack](https://slack.minio.io/slack?type=svg)](https://slack.minio.io)
The Minio Go Client SDK provides simple APIs to access any Amazon S3 compatible object storage.
**Supported cloud storage providers:**
@ -14,22 +15,21 @@ The Minio Golang Client SDK provides simple APIs to access any Amazon S3 compati
- Ceph Object Gateway
- Riak CS
This quickstart guide will show you how to install the Minio client SDK, connect to Minio, and provide a walkthrough of a simple file uploader. For a complete list of APIs and examples, please take a look at the [Golang Client API Reference](https://docs.minio.io/docs/golang-client-api-reference).
This quickstart guide will show you how to install the Minio client SDK, connect to Minio, and provide a walkthrough for a simple file uploader. For a complete list of APIs and examples, please take a look at the [Go Client API Reference](https://docs.minio.io/docs/golang-client-api-reference).
This document assumes that you have a working [Golang setup](https://docs.minio.io/docs/how-to-install-golang).
This document assumes that you have a working [Go development environment](https://docs.minio.io/docs/how-to-install-golang).
## Download from Github
```sh
$ go get -u github.com/minio/minio-go
go get -u github.com/minio/minio-go
```
## Initialize Minio Client
You need four items to connect to Minio object storage server.
Minio client requires the following four parameters specified to connect to an Amazon S3 compatible object storage.
| Parameter | Description|
@ -68,7 +68,7 @@ func main() {
## Quick Start Example - File Uploader
This example program connects to an object storage server, makes a bucket on the server and then uploads a file to the bucket.
This example program connects to an object storage server, creates a bucket and uploads a file to the bucket.
@ -132,11 +132,11 @@ func main() {
```sh
$ go run file-uploader.go
go run file-uploader.go
2016/08/13 17:03:28 Successfully created mymusic
2016/08/13 17:03:40 Successfully uploaded golden-oldies.zip of size 16253413
$ mc ls play/mymusic/
mc ls play/mymusic/
[2016-05-27 16:02:16 PDT] 17MiB golden-oldies.zip
```
@ -161,6 +161,7 @@ The full API Reference is available here.
* [`SetBucketPolicy`](https://docs.minio.io/docs/golang-client-api-reference#SetBucketPolicy)
* [`GetBucketPolicy`](https://docs.minio.io/docs/golang-client-api-reference#GetBucketPolicy)
* [`ListBucketPolicies`](https://docs.minio.io/docs/golang-client-api-reference#ListBucketPolicies)
### API Reference : Bucket notification Operations
@ -173,14 +174,15 @@ The full API Reference is available here.
* [`FPutObject`](https://docs.minio.io/docs/golang-client-api-reference#FPutObject)
* [`FGetObject`](https://docs.minio.io/docs/golang-client-api-reference#FPutObject)
* [`CopyObject`](https://docs.minio.io/docs/golang-client-api-reference#CopyObject)
### API Reference : Object Operations
* [`GetObject`](https://docs.minio.io/docs/golang-client-api-reference#GetObject)
* [`PutObject`](https://docs.minio.io/docs/golang-client-api-reference#PutObject)
* [`StatObject`](https://docs.minio.io/docs/golang-client-api-reference#StatObject)
* [`CopyObject`](https://docs.minio.io/docs/golang-client-api-reference#CopyObject)
* [`RemoveObject`](https://docs.minio.io/docs/golang-client-api-reference#RemoveObject)
* [`RemoveObjects`](https://docs.minio.io/docs/golang-client-api-reference#RemoveObjects)
* [`RemoveIncompleteUpload`](https://docs.minio.io/docs/golang-client-api-reference#RemoveIncompleteUpload)
### API Reference : Presigned Operations
@ -189,44 +191,52 @@ The full API Reference is available here.
* [`PresignedPutObject`](https://docs.minio.io/docs/golang-client-api-reference#PresignedPutObject)
* [`PresignedPostPolicy`](https://docs.minio.io/docs/golang-client-api-reference#PresignedPostPolicy)
### API Reference : Client custom settings
* [`SetAppInfo`](http://docs.minio.io/docs/golang-client-api-reference#SetAppInfo)
* [`SetCustomTransport`](http://docs.minio.io/docs/golang-client-api-reference#SetCustomTransport)
* [`TraceOn`](http://docs.minio.io/docs/golang-client-api-reference#TraceOn)
* [`TraceOff`](http://docs.minio.io/docs/golang-client-api-reference#TraceOff)
## Full Examples
#### Full Examples : Bucket Operations
* [listbuckets.go](https://github.com/minio/minio-go/blob/master/examples/s3/listbuckets.go)
* [listobjects.go](https://github.com/minio/minio-go/blob/master/examples/s3/listobjects.go)
* [bucketexists.go](https://github.com/minio/minio-go/blob/master/examples/s3/bucketexists.go)
* [makebucket.go](https://github.com/minio/minio-go/blob/master/examples/s3/makebucket.go)
* [listbuckets.go](https://github.com/minio/minio-go/blob/master/examples/s3/listbuckets.go)
* [bucketexists.go](https://github.com/minio/minio-go/blob/master/examples/s3/bucketexists.go)
* [removebucket.go](https://github.com/minio/minio-go/blob/master/examples/s3/removebucket.go)
* [listobjects.go](https://github.com/minio/minio-go/blob/master/examples/s3/listobjects.go)
* [listobjectsV2.go](https://github.com/minio/minio-go/blob/master/examples/s3/listobjectsV2.go)
* [listincompleteuploads.go](https://github.com/minio/minio-go/blob/master/examples/s3/listincompleteuploads.go)
#### Full Examples : Bucket policy Operations
* [setbucketpolicy.go](https://github.com/minio/minio-go/blob/master/examples/s3/setbucketpolicy.go)
* [getbucketpolicy.go](https://github.com/minio/minio-go/blob/master/examples/s3/getbucketpolicy.go)
* [listbucketpolicies.go](https://github.com/minio/minio-go/blob/master/examples/s3/listbucketpolicies.go)
#### Full Examples : Bucket notification Operations
* [setbucketnotification.go](https://github.com/minio/minio-go/blob/master/examples/s3/setbucketnotification.go)
* [getbucketnotification.go](https://github.com/minio/minio-go/blob/master/examples/s3/getbucketnotification.go)
* [deletebucketnotification.go](https://github.com/minio/minio-go/blob/master/examples/s3/deletebucketnotification.go)
* [removeallbucketnotification.go](https://github.com/minio/minio-go/blob/master/examples/s3/removeallbucketnotification.go)
* [listenbucketnotification.go](https://github.com/minio/minio-go/blob/master/examples/minio/listenbucketnotification.go) (Minio Extension)
#### Full Examples : File Object Operations
* [fputobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/fputobject.go)
* [fgetobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/fgetobject.go)
* [copyobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/copyobject.go)
#### Full Examples : Object Operations
* [putobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/putobject.go)
* [getobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/getobject.go)
* [listobjects.go](https://github.com/minio/minio-go/blob/master/examples/s3/listobjects.go)
* [listobjectsV2.go](https://github.com/minio/minio-go/blob/master/examples/s3/listobjectsV2.go)
* [removeobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/removeobject.go)
* [statobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/statobject.go)
* [copyobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/copyobject.go)
* [removeobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/removeobject.go)
* [removeincompleteupload.go](https://github.com/minio/minio-go/blob/master/examples/s3/removeincompleteupload.go)
* [removeobjects.go](https://github.com/minio/minio-go/blob/master/examples/s3/removeobjects.go)
#### Full Examples : Presigned Operations
* [presignedgetobject.go](https://github.com/minio/minio-go/blob/master/examples/s3/presignedgetobject.go)
@ -235,7 +245,7 @@ The full API Reference is available here.
## Explore Further
* [Complete Documentation](https://docs.minio.io)
* [Minio Golang Client SDK API Reference](https://docs.minio.io/docs/golang-client-api-reference)
* [Minio Go Client SDK API Reference](https://docs.minio.io/docs/golang-client-api-reference)
* [Go Music Player App- Full Application Example ](https://docs.minio.io/docs/go-music-player-app)
## Contribute

View File

@ -16,7 +16,10 @@
package minio
import "time"
import (
"net/http"
"time"
)
// BucketInfo container for bucket metadata.
type BucketInfo struct {
@ -38,6 +41,10 @@ type ObjectInfo struct {
Size int64 `json:"size"` // Size in bytes of the object.
ContentType string `json:"contentType"` // A standard MIME type describing the format of the object data.
// Collection of additional metadata on the object.
// eg: x-amz-meta-*, content-encoding etc.
Metadata http.Header `json:"metadata"`
// Owner name.
Owner struct {
DisplayName string `json:"name"`

View File

@ -149,6 +149,16 @@ func httpRespToErrorResponse(resp *http.Response, bucketName, objectName string)
return errResp
}
// ErrTransferAccelerationBucket - bucket name is invalid to be used with transfer acceleration.
func ErrTransferAccelerationBucket(bucketName string) error {
msg := fmt.Sprintf("The name of the bucket used for Transfer Acceleration must be DNS-compliant and must not contain periods (\".\").")
return ErrorResponse{
Code: "InvalidArgument",
Message: msg,
BucketName: bucketName,
}
}
// ErrEntityTooLarge - Input size is larger than supported maximum.
func ErrEntityTooLarge(totalSize, maxObjectSize int64, bucketName, objectName string) error {
msg := fmt.Sprintf("Your proposed upload size %d exceeds the maximum allowed object size %d for single PUT operation.", totalSize, maxObjectSize)
@ -201,16 +211,6 @@ func ErrInvalidObjectName(message string) error {
}
}
// ErrInvalidParts - Invalid number of parts.
func ErrInvalidParts(expectedParts, uploadedParts int) error {
msg := fmt.Sprintf("Unexpected number of parts found Want %d, Got %d", expectedParts, uploadedParts)
return ErrorResponse{
Code: "InvalidParts",
Message: msg,
RequestID: "minio",
}
}
// ErrInvalidObjectPrefix - Invalid object prefix response is
// similar to object name response.
var ErrInvalidObjectPrefix = ErrInvalidObjectName

View File

@ -249,20 +249,6 @@ func TestErrInvalidObjectName(t *testing.T) {
}
}
// Test validates 'ErrInvalidParts' error response.
func TestErrInvalidParts(t *testing.T) {
msg := fmt.Sprintf("Unexpected number of parts found Want %d, Got %d", 10, 9)
expectedResult := ErrorResponse{
Code: "InvalidParts",
Message: msg,
RequestID: "minio",
}
actualResult := ErrInvalidParts(10, 9)
if !reflect.DeepEqual(expectedResult, actualResult) {
t.Errorf("Expected result to be '%+v', but instead got '%+v'", expectedResult, actualResult)
}
}
// Test validates 'ErrInvalidArgument' response.
func TestErrInvalidArgument(t *testing.T) {
expectedResult := ErrorResponse{

View File

@ -73,7 +73,9 @@ func (c Client) GetObject(bucketName, objectName string) (*Object, error) {
if req.isReadAt {
// If this is a ReadAt request only get the specified range.
// Range is set with respect to the offset and length of the buffer requested.
httpReader, objectInfo, err = c.getObject(bucketName, objectName, req.Offset, int64(len(req.Buffer)))
// Do not set objectInfo from the first readAt request because it will not get
// the whole object.
httpReader, _, err = c.getObject(bucketName, objectName, req.Offset, int64(len(req.Buffer)))
} else {
// First request is a Read request.
httpReader, objectInfo, err = c.getObject(bucketName, objectName, req.Offset, 0)
@ -115,6 +117,19 @@ func (c Client) GetObject(bucketName, objectName string) (*Object, error) {
objectInfo: objectInfo,
}
}
} else if req.settingObjectInfo { // Request is just to get objectInfo.
objectInfo, err := c.StatObject(bucketName, objectName)
if err != nil {
resCh <- getResponse{
Error: err,
}
// Exit the goroutine.
return
}
// Send back the objectInfo.
resCh <- getResponse{
objectInfo: objectInfo,
}
} else {
// Offset changes fetch the new object at an Offset.
// Because the httpReader may not be set by the first
@ -132,7 +147,7 @@ func (c Client) GetObject(bucketName, objectName string) (*Object, error) {
// Range is set with respect to the offset and length of the buffer requested.
httpReader, _, err = c.getObject(bucketName, objectName, req.Offset, int64(len(req.Buffer)))
} else {
httpReader, _, err = c.getObject(bucketName, objectName, req.Offset, 0)
httpReader, objectInfo, err = c.getObject(bucketName, objectName, req.Offset, 0)
}
if err != nil {
resCh <- getResponse{
@ -152,9 +167,10 @@ func (c Client) GetObject(bucketName, objectName string) (*Object, error) {
}
// Reply back how much was read.
resCh <- getResponse{
Size: int(size),
Error: err,
didRead: true,
Size: int(size),
Error: err,
didRead: true,
objectInfo: objectInfo,
}
}
}
@ -168,13 +184,14 @@ func (c Client) GetObject(bucketName, objectName string) (*Object, error) {
// get request message container to communicate with internal
// go-routine.
type getRequest struct {
Buffer []byte
Offset int64 // readAt offset.
DidOffsetChange bool // Tracks the offset changes for Seek requests.
beenRead bool // Determines if this is the first time an object is being read.
isReadAt bool // Determines if this request is a request to a specific range
isReadOp bool // Determines if this request is a Read or Read/At request.
isFirstReq bool // Determines if this request is the first time an object is being accessed.
Buffer []byte
Offset int64 // readAt offset.
DidOffsetChange bool // Tracks the offset changes for Seek requests.
beenRead bool // Determines if this is the first time an object is being read.
isReadAt bool // Determines if this request is a request to a specific range
isReadOp bool // Determines if this request is a Read or Read/At request.
isFirstReq bool // Determines if this request is the first time an object is being accessed.
settingObjectInfo bool // Determines if this request is to set the objectInfo of an object.
}
// get response message container to reply back for the request.
@ -195,10 +212,12 @@ type Object struct {
reqCh chan<- getRequest
resCh <-chan getResponse
doneCh chan<- struct{}
prevOffset int64
currOffset int64
objectInfo ObjectInfo
// Ask lower level to initiate data fetching based on currOffset
seekData bool
// Keeps track of closed call.
isClosed bool
@ -210,6 +229,9 @@ type Object struct {
// Keeps track of if this object has been read yet.
beenRead bool
// Keeps track of if objectInfo has been set yet.
objectInfoSet bool
}
// doGetRequest - sends and blocks on the firstReqCh and reqCh of an object.
@ -221,11 +243,15 @@ func (o *Object) doGetRequest(request getRequest) (getResponse, error) {
response := <-o.resCh
// This was the first request.
if !o.isStarted {
// Set objectInfo for first time.
o.objectInfo = response.objectInfo
// The object has been operated on.
o.isStarted = true
}
// Set the objectInfo if the request was not readAt
// and it hasn't been set before.
if !o.objectInfoSet && !request.isReadAt {
o.objectInfo = response.objectInfo
o.objectInfoSet = true
}
// Set beenRead only if it has not been set before.
if !o.beenRead {
o.beenRead = response.didRead
@ -235,6 +261,9 @@ func (o *Object) doGetRequest(request getRequest) (getResponse, error) {
return response, response.Error
}
// Data are ready on the wire, no need to reinitiate connection in lower level
o.seekData = false
return response, nil
}
@ -243,8 +272,6 @@ func (o *Object) doGetRequest(request getRequest) (getResponse, error) {
func (o *Object) setOffset(bytesRead int64) error {
// Update the currentOffset.
o.currOffset += bytesRead
// Save the current offset as previous offset.
o.prevOffset = o.currOffset
if o.currOffset >= o.objectInfo.Size {
return io.EOF
@ -252,7 +279,7 @@ func (o *Object) setOffset(bytesRead int64) error {
return nil
}
// Read reads up to len(p) bytes into p. It returns the number of
// Read reads up to len(b) bytes into b. It returns the number of
// bytes read (0 <= n <= len(p)) and any error encountered. Returns
// io.EOF upon end of file.
func (o *Object) Read(b []byte) (n int, err error) {
@ -280,27 +307,14 @@ func (o *Object) Read(b []byte) (n int, err error) {
readReq.isFirstReq = true
}
// Verify if offset has changed and currOffset is greater than
// previous offset. Perhaps due to Seek().
offsetChange := o.prevOffset - o.currOffset
if offsetChange < 0 {
offsetChange = -offsetChange
}
if offsetChange > 0 {
// Fetch the new reader at the current offset again.
readReq.Offset = o.currOffset
readReq.DidOffsetChange = true
} else {
// No offset changes no need to fetch new reader, continue
// reading.
readReq.DidOffsetChange = false
readReq.Offset = 0
}
// Ask to establish a new data fetch routine based on seekData flag
readReq.DidOffsetChange = o.seekData
readReq.Offset = o.currOffset
// Send and receive from the first request.
response, err := o.doGetRequest(readReq)
if err != nil {
// Save the error.
if err != nil && err != io.EOF {
// Save the error for future calls.
o.prevErr = err
return response.Size, err
}
@ -309,14 +323,18 @@ func (o *Object) Read(b []byte) (n int, err error) {
bytesRead := int64(response.Size)
// Set the new offset.
err = o.setOffset(bytesRead)
if err != nil {
return response.Size, err
oerr := o.setOffset(bytesRead)
if oerr != nil {
// Save the error for future calls.
o.prevErr = oerr
return response.Size, oerr
}
return response.Size, nil
// Return the response.
return response.Size, err
}
// Stat returns the ObjectInfo structure describing object.
// Stat returns the ObjectInfo structure describing Object.
func (o *Object) Stat() (ObjectInfo, error) {
if o == nil {
return ObjectInfo{}, ErrInvalidArgument("Object is nil")
@ -325,16 +343,15 @@ func (o *Object) Stat() (ObjectInfo, error) {
o.mutex.Lock()
defer o.mutex.Unlock()
if o.prevErr != nil || o.isClosed {
if o.prevErr != nil && o.prevErr != io.EOF || o.isClosed {
return ObjectInfo{}, o.prevErr
}
// This is the first request.
if !o.isStarted {
if !o.isStarted || !o.objectInfoSet {
statReq := getRequest{
isReadOp: false, // This is a Stat not a Read/ReadAt.
Offset: 0,
isFirstReq: true,
isFirstReq: !o.isStarted,
settingObjectInfo: !o.objectInfoSet,
}
// Send the request and get the response.
@ -365,8 +382,9 @@ func (o *Object) ReadAt(b []byte, offset int64) (n int, err error) {
if o.prevErr != nil || o.isClosed {
return 0, o.prevErr
}
// Can only compare offsets to size when size has been set.
if o.isStarted {
if o.objectInfoSet {
// If offset is negative than we return io.EOF.
// If offset is greater than or equal to object size we return io.EOF.
if offset >= o.objectInfo.Size || offset < 0 {
@ -383,6 +401,7 @@ func (o *Object) ReadAt(b []byte, offset int64) (n int, err error) {
Offset: offset, // Set the offset.
Buffer: b,
}
// Alert that this is the first request.
if !o.isStarted {
readAtReq.isFirstReq = true
@ -390,21 +409,29 @@ func (o *Object) ReadAt(b []byte, offset int64) (n int, err error) {
// Send and receive from the first request.
response, err := o.doGetRequest(readAtReq)
if err != nil {
if err != nil && err != io.EOF {
// Save the error.
o.prevErr = err
return 0, err
return response.Size, err
}
// Bytes read.
bytesRead := int64(response.Size)
// Update the offsets.
err = o.setOffset(bytesRead)
if err != nil {
return response.Size, err
// There is no valid objectInfo yet
// to compare against for EOF.
if !o.objectInfoSet {
// Update the currentOffset.
o.currOffset += bytesRead
} else {
// If this was not the first request update
// the offsets and compare against objectInfo
// for EOF.
oerr := o.setOffset(bytesRead)
if oerr != nil {
o.prevErr = oerr
return response.Size, oerr
}
}
return response.Size, nil
return response.Size, err
}
// Seek sets the offset for the next Read or Write to offset,
@ -439,7 +466,7 @@ func (o *Object) Seek(offset int64, whence int) (n int64, err error) {
// This is the first request. So before anything else
// get the ObjectInfo.
if !o.isStarted {
if !o.isStarted || !o.objectInfoSet {
// Create the new Seek request.
seekReq := getRequest{
isReadOp: false,
@ -454,8 +481,6 @@ func (o *Object) Seek(offset int64, whence int) (n int64, err error) {
return 0, err
}
}
// Save current offset as previous offset.
o.prevOffset = o.currOffset
// Switch through whence.
switch whence {
@ -489,6 +514,10 @@ func (o *Object) Seek(offset int64, whence int) (n int64, err error) {
if o.prevErr == io.EOF {
o.prevErr = nil
}
// Ask lower level to fetch again from source
o.seekData = true
// Return the effective offset.
return o.currOffset, nil
}

View File

@ -41,7 +41,23 @@ func (c Client) GetBucketPolicy(bucketName, objectPrefix string) (bucketPolicy p
return policy.GetPolicy(policyInfo.Statements, bucketName, objectPrefix), nil
}
// Request server for policy.
// ListBucketPolicies - list all policies for a given prefix and all its children.
func (c Client) ListBucketPolicies(bucketName, objectPrefix string) (bucketPolicies map[string]policy.BucketPolicy, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return map[string]policy.BucketPolicy{}, err
}
if err := isValidObjectPrefix(objectPrefix); err != nil {
return map[string]policy.BucketPolicy{}, err
}
policyInfo, err := c.getBucketPolicy(bucketName, objectPrefix)
if err != nil {
return map[string]policy.BucketPolicy{}, err
}
return policy.GetPolicies(policyInfo.Statements, bucketName), nil
}
// Request server for current bucket policy.
func (c Client) getBucketPolicy(bucketName string, objectPrefix string) (policy.BucketAccessPolicy, error) {
// Get resources properly escaped and lined up before
// using them in http request.

View File

@ -84,6 +84,8 @@ func (c Client) ListObjectsV2(bucketName, objectPrefix string, recursive bool, d
// If recursive we do not delimit.
delimiter = ""
}
// Return object owner information by default
fetchOwner := true
// Validate bucket name.
if err := isValidBucketName(bucketName); err != nil {
defer close(objectStatCh)
@ -108,7 +110,7 @@ func (c Client) ListObjectsV2(bucketName, objectPrefix string, recursive bool, d
var continuationToken string
for {
// Get list of objects a maximum of 1000 per request.
result, err := c.listObjectsV2Query(bucketName, objectPrefix, continuationToken, delimiter, 1000)
result, err := c.listObjectsV2Query(bucketName, objectPrefix, continuationToken, fetchOwner, delimiter, 1000)
if err != nil {
objectStatCh <- ObjectInfo{
Err: err,
@ -166,7 +168,7 @@ func (c Client) ListObjectsV2(bucketName, objectPrefix string, recursive bool, d
// ?delimiter - A delimiter is a character you use to group keys.
// ?prefix - Limits the response to keys that begin with the specified prefix.
// ?max-keys - Sets the maximum number of keys returned in the response body.
func (c Client) listObjectsV2Query(bucketName, objectPrefix, continuationToken, delimiter string, maxkeys int) (listBucketV2Result, error) {
func (c Client) listObjectsV2Query(bucketName, objectPrefix, continuationToken string, fetchOwner bool, delimiter string, maxkeys int) (listBucketV2Result, error) {
// Validate bucket name.
if err := isValidBucketName(bucketName); err != nil {
return listBucketV2Result{}, err
@ -195,6 +197,11 @@ func (c Client) listObjectsV2Query(bucketName, objectPrefix, continuationToken,
urlValues.Set("delimiter", delimiter)
}
// Fetch owner when listing
if fetchOwner {
urlValues.Set("fetch-owner", "true")
}
// maxkeys should default to 1000 or less.
if maxkeys == 0 || maxkeys > 1000 {
maxkeys = 1000
@ -475,6 +482,7 @@ func (c Client) listIncompleteUploads(bucketName, objectPrefix string, recursive
objectMultipartStatCh <- ObjectMultipartInfo{
Err: err,
}
continue
}
}
select {

View File

@ -22,6 +22,9 @@ import (
"io"
"net/http"
"net/url"
"time"
"github.com/minio/minio-go/pkg/s3utils"
)
// GetBucketNotification - get bucket notification at a given path.
@ -120,7 +123,7 @@ type NotificationInfo struct {
}
// ListenBucketNotification - listen on bucket notifications.
func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, doneCh <-chan struct{}) <-chan NotificationInfo {
func (c Client) ListenBucketNotification(bucketName, prefix, suffix string, events []string, doneCh <-chan struct{}) <-chan NotificationInfo {
notificationInfoCh := make(chan NotificationInfo, 1)
// Only success, start a routine to start reading line by line.
go func(notificationInfoCh chan<- NotificationInfo) {
@ -135,7 +138,7 @@ func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, done
}
// Check ARN partition to verify if listening bucket is supported
if accountArn.Partition != "minio" {
if s3utils.IsAmazonEndpoint(c.endpointURL) || s3utils.IsGoogleEndpoint(c.endpointURL) {
notificationInfoCh <- NotificationInfo{
Err: ErrAPINotSupported("Listening bucket notification is specific only to `minio` partitions"),
}
@ -143,9 +146,18 @@ func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, done
}
// Continously run and listen on bucket notification.
for {
// Create a done channel to control 'ListObjects' go routine.
retryDoneCh := make(chan struct{}, 1)
// Indicate to our routine to exit cleanly upon return.
defer close(retryDoneCh)
// Wait on the jitter retry loop.
for range c.newRetryTimerContinous(time.Second, time.Second*30, MaxJitter, retryDoneCh) {
urlValues := make(url.Values)
urlValues.Set("notificationARN", accountArn.String())
urlValues.Set("prefix", prefix)
urlValues.Set("suffix", suffix)
urlValues["events"] = events
// Execute GET on bucket to list objects.
resp, err := c.executeMethod("GET", requestMetadata{
@ -153,10 +165,7 @@ func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, done
queryValues: urlValues,
})
if err != nil {
notificationInfoCh <- NotificationInfo{
Err: err,
}
return
continue
}
// Validate http response, upon error return quickly.
@ -178,10 +187,7 @@ func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, done
for bio.Scan() {
var notificationInfo NotificationInfo
if err = json.Unmarshal(bio.Bytes(), &notificationInfo); err != nil {
notificationInfoCh <- NotificationInfo{
Err: err,
}
return
continue
}
// Send notifications on channel only if there are events received.
if len(notificationInfo.Records) > 0 {
@ -198,12 +204,7 @@ func (c Client) ListenBucketNotification(bucketName string, accountArn Arn, done
// and re-connect.
if err == io.ErrUnexpectedEOF {
resp.Body.Close()
continue
}
notificationInfoCh <- NotificationInfo{
Err: err,
}
return
}
}
}(notificationInfoCh)

View File

@ -20,6 +20,9 @@ import (
"errors"
"net/url"
"time"
"github.com/minio/minio-go/pkg/s3signer"
"github.com/minio/minio-go/pkg/s3utils"
)
// supportedGetReqParams - supported request parameters for GET presigned request.
@ -126,14 +129,14 @@ func (c Client) PresignedPostPolicy(p *PostPolicy) (u *url.URL, formData map[str
policyBase64 := p.base64()
p.formData["policy"] = policyBase64
// For Google endpoint set this value to be 'GoogleAccessId'.
if isGoogleEndpoint(c.endpointURL) {
if s3utils.IsGoogleEndpoint(c.endpointURL) {
p.formData["GoogleAccessId"] = c.accessKeyID
} else {
// For all other endpoints set this value to be 'AWSAccessKeyId'.
p.formData["AWSAccessKeyId"] = c.accessKeyID
}
// Sign the policy.
p.formData["signature"] = postPresignSignatureV2(policyBase64, c.secretAccessKey)
p.formData["signature"] = s3signer.PostPresignSignatureV2(policyBase64, c.secretAccessKey)
return u, p.formData, nil
}
@ -156,7 +159,7 @@ func (c Client) PresignedPostPolicy(p *PostPolicy) (u *url.URL, formData map[str
}
// Add a credential policy.
credential := getCredential(c.accessKeyID, location, t)
credential := s3signer.GetCredential(c.accessKeyID, location, t)
if err = p.addNewPolicy(policyCondition{
matchType: "eq",
condition: "$x-amz-credential",
@ -172,6 +175,6 @@ func (c Client) PresignedPostPolicy(p *PostPolicy) (u *url.URL, formData map[str
p.formData["x-amz-algorithm"] = signV4Algorithm
p.formData["x-amz-credential"] = credential
p.formData["x-amz-date"] = t.Format(iso8601DateFormat)
p.formData["x-amz-signature"] = postPresignSignatureV4(policyBase64, t, c.secretAccessKey, location)
p.formData["x-amz-signature"] = s3signer.PostPresignSignatureV4(policyBase64, t, c.secretAccessKey, location)
return u, p.formData, nil
}

View File

@ -26,8 +26,10 @@ import (
"io/ioutil"
"net/http"
"net/url"
"path"
"github.com/minio/minio-go/pkg/policy"
"github.com/minio/minio-go/pkg/s3signer"
)
/// Bucket operations
@ -89,11 +91,8 @@ func (c Client) makeBucketRequest(bucketName string, location string) (*http.Req
// is the preferred method here. The final location of the
// 'bucket' is provided through XML LocationConstraint data with
// the request.
targetURL, err := url.Parse(c.endpointURL)
if err != nil {
return nil, err
}
targetURL.Path = "/" + bucketName + "/"
targetURL := c.endpointURL
targetURL.Path = path.Join(bucketName, "") + "/"
// get a new HTTP request for the method.
req, err := http.NewRequest("PUT", targetURL.String(), nil)
@ -133,9 +132,9 @@ func (c Client) makeBucketRequest(bucketName string, location string) (*http.Req
if c.signature.isV4() {
// Signature calculated for MakeBucket request should be for 'us-east-1',
// regardless of the bucket's location constraint.
req = signV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
req = s3signer.SignV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
} else if c.signature.isV2() {
req = signV2(*req, c.accessKeyID, c.secretAccessKey)
req = s3signer.SignV2(*req, c.accessKeyID, c.secretAccessKey)
}
// Return signed request.

View File

@ -24,8 +24,10 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"testing"
"github.com/minio/minio-go/pkg/s3signer"
)
// Tests validate http request formulated for creation of bucket.
@ -33,14 +35,11 @@ func TestMakeBucketRequest(t *testing.T) {
// Generates expected http request for bucket creation.
// Used for asserting with the actual request generated.
createExpectedRequest := func(c *Client, bucketName string, location string, req *http.Request) (*http.Request, error) {
targetURL, err := url.Parse(c.endpointURL)
if err != nil {
return nil, err
}
targetURL.Path = "/" + bucketName + "/"
targetURL := c.endpointURL
targetURL.Path = path.Join(bucketName, "") + "/"
// get a new HTTP request for the method.
var err error
req, err = http.NewRequest("PUT", targetURL.String(), nil)
if err != nil {
return nil, err
@ -78,9 +77,9 @@ func TestMakeBucketRequest(t *testing.T) {
if c.signature.isV4() {
// Signature calculated for MakeBucket request should be for 'us-east-1',
// regardless of the bucket's location constraint.
req = signV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
req = s3signer.SignV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
} else if c.signature.isV2() {
req = signV2(*req, c.accessKeyID, c.secretAccessKey)
req = s3signer.SignV2(*req, c.accessKeyID, c.secretAccessKey)
}
// Return signed request.

View File

@ -44,18 +44,17 @@ func isReadAt(reader io.Reader) (ok bool) {
}
// shouldUploadPart - verify if part should be uploaded.
func shouldUploadPart(objPart objectPart, objectParts map[int]objectPart) bool {
func shouldUploadPart(objPart objectPart, uploadReq uploadPartReq) bool {
// If part not found should upload the part.
uploadedPart, found := objectParts[objPart.PartNumber]
if !found {
if uploadReq.Part == nil {
return true
}
// if size mismatches should upload the part.
if objPart.Size != uploadedPart.Size {
if objPart.Size != uploadReq.Part.Size {
return true
}
// if md5sum mismatches should upload the part.
if objPart.ETag != uploadedPart.ETag {
if objPart.ETag != uploadReq.Part.ETag {
return true
}
return false
@ -68,7 +67,7 @@ func shouldUploadPart(objPart objectPart, objectParts map[int]objectPart) bool {
// object storage it will have the following parameters as constants.
//
// maxPartsCount - 10000
// minPartSize - 5MiB
// minPartSize - 64MiB
// maxMultipartPutObjectSize - 5TiB
//
func optimalPartInfo(objectSize int64) (totalPartsCount int, partSize int64, lastPartSize int64, err error) {
@ -167,37 +166,64 @@ func hashCopyN(hashAlgorithms map[string]hash.Hash, hashSums map[string][]byte,
// getUploadID - fetch upload id if already present for an object name
// or initiate a new request to fetch a new upload id.
func (c Client) getUploadID(bucketName, objectName, contentType string) (uploadID string, isNew bool, err error) {
func (c Client) newUploadID(bucketName, objectName string, metaData map[string][]string) (uploadID string, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return "", false, err
return "", err
}
if err := isValidObjectName(objectName); err != nil {
return "", false, err
return "", err
}
// Set content Type to default if empty string.
if contentType == "" {
contentType = "application/octet-stream"
}
// Find upload id for previous upload for an object.
uploadID, err = c.findUploadID(bucketName, objectName)
// Initiate multipart upload for an object.
initMultipartUploadResult, err := c.initiateMultipartUpload(bucketName, objectName, metaData)
if err != nil {
return "", false, err
return "", err
}
return initMultipartUploadResult.UploadID, nil
}
// getMpartUploadSession returns the upload id and the uploaded parts to continue a previous upload session
// or initiate a new multipart session if no current one found
func (c Client) getMpartUploadSession(bucketName, objectName string, metaData map[string][]string) (string, map[int]objectPart, error) {
// A map of all uploaded parts.
var partsInfo map[int]objectPart
var err error
uploadID, err := c.findUploadID(bucketName, objectName)
if err != nil {
return "", nil, err
}
if uploadID == "" {
// Initiate multipart upload for an object.
initMultipartUploadResult, err := c.initiateMultipartUpload(bucketName, objectName, contentType)
// Initiates a new multipart request
uploadID, err = c.newUploadID(bucketName, objectName, metaData)
if err != nil {
return "", false, err
return "", nil, err
}
} else {
// Fetch previously upload parts and maximum part size.
partsInfo, err = c.listObjectParts(bucketName, objectName, uploadID)
if err != nil {
// When the server returns NoSuchUpload even if its previouls acknowleged the existance of the upload id,
// initiate a new multipart upload
if respErr, ok := err.(ErrorResponse); ok && respErr.Code == "NoSuchUpload" {
uploadID, err = c.newUploadID(bucketName, objectName, metaData)
if err != nil {
return "", nil, err
}
} else {
return "", nil, err
}
}
// Save the new upload id.
uploadID = initMultipartUploadResult.UploadID
// Indicate that this is a new upload id.
isNew = true
}
return uploadID, isNew, nil
// Allocate partsInfo if not done yet
if partsInfo == nil {
partsInfo = make(map[int]objectPart)
}
return uploadID, partsInfo, nil
}
// computeHash - Calculates hashes for an input read Seeker.

View File

@ -16,7 +16,11 @@
package minio
import "net/http"
import (
"net/http"
"github.com/minio/minio-go/pkg/s3utils"
)
// CopyObject - copy a source object into a new object with the provided name in the provided bucket
func (c Client) CopyObject(bucketName string, objectName string, objectSource string, cpCond CopyConditions) error {
@ -38,7 +42,7 @@ func (c Client) CopyObject(bucketName string, objectName string, objectSource st
}
// Set copy source.
customHeaders.Set("x-amz-copy-source", urlEncodePath(objectSource))
customHeaders.Set("x-amz-copy-source", s3utils.EncodePath(objectSource))
// Execute PUT on objectName.
resp, err := c.executeMethod("PUT", requestMetadata{

View File

@ -28,6 +28,8 @@ import (
"os"
"path/filepath"
"sort"
"github.com/minio/minio-go/pkg/s3utils"
)
// FPutObject - Create an object in a bucket, with contents from file at filePath.
@ -62,6 +64,8 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
return 0, ErrEntityTooLarge(fileSize, maxMultipartPutObjectSize, bucketName, objectName)
}
objMetadata := make(map[string][]string)
// Set contentType based on filepath extension if not given or default
// value of "binary/octet-stream" if the extension has no associated type.
if contentType == "" {
@ -70,9 +74,11 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
}
}
objMetadata["Content-Type"] = []string{contentType}
// NOTE: Google Cloud Storage multipart Put is not compatible with Amazon S3 APIs.
// Current implementation will only upload a maximum of 5GiB to Google Cloud Storage servers.
if isGoogleEndpoint(c.endpointURL) {
if s3utils.IsGoogleEndpoint(c.endpointURL) {
if fileSize > int64(maxSinglePutObjectSize) {
return 0, ErrorResponse{
Code: "NotImplemented",
@ -82,11 +88,11 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
}
}
// Do not compute MD5 for Google Cloud Storage. Uploads up to 5GiB in size.
return c.putObjectNoChecksum(bucketName, objectName, fileReader, fileSize, contentType, nil)
return c.putObjectNoChecksum(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
// NOTE: S3 doesn't allow anonymous multipart requests.
if isAmazonEndpoint(c.endpointURL) && c.anonymous {
if s3utils.IsAmazonEndpoint(c.endpointURL) && c.anonymous {
if fileSize > int64(maxSinglePutObjectSize) {
return 0, ErrorResponse{
Code: "NotImplemented",
@ -97,15 +103,15 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
}
// Do not compute MD5 for anonymous requests to Amazon
// S3. Uploads up to 5GiB in size.
return c.putObjectNoChecksum(bucketName, objectName, fileReader, fileSize, contentType, nil)
return c.putObjectNoChecksum(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
// Small object upload is initiated for uploads for input data size smaller than 5MiB.
if fileSize < minPartSize && fileSize >= 0 {
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, contentType, nil)
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
// Upload all large objects as multipart.
n, err = c.putObjectMultipartFromFile(bucketName, objectName, fileReader, fileSize, contentType, nil)
n, err = c.putObjectMultipartFromFile(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
@ -116,7 +122,7 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
return 0, ErrEntityTooLarge(fileSize, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, contentType, nil)
return c.putObjectSingle(bucketName, objectName, fileReader, fileSize, objMetadata, nil)
}
return n, err
}
@ -131,7 +137,7 @@ func (c Client) FPutObject(bucketName, objectName, filePath, contentType string)
// against MD5SUM of each individual parts. This function also
// effectively utilizes file system capabilities of reading from
// specific sections and not having to create temporary files.
func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileReader io.ReaderAt, fileSize int64, contentType string, progress io.Reader) (int64, error) {
func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileReader io.ReaderAt, fileSize int64, metaData map[string][]string, progress io.Reader) (int64, error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -140,9 +146,8 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
return 0, err
}
// Get upload id for an object, initiates a new multipart request
// if it cannot find any previously partially uploaded object.
uploadID, isNew, err := c.getUploadID(bucketName, objectName, contentType)
// Get the upload id of a previously partially uploaded object or initiate a new multipart upload
uploadID, partsInfo, err := c.getMpartUploadSession(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
@ -151,83 +156,139 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
var totalUploadedSize int64
// Complete multipart upload.
var completeMultipartUpload completeMultipartUpload
// A map of all uploaded parts.
var partsInfo = make(map[int]objectPart)
// If this session is a continuation of a previous session fetch all
// previously uploaded parts info.
if !isNew {
// Fetch previously upload parts and maximum part size.
partsInfo, err = c.listObjectParts(bucketName, objectName, uploadID)
if err != nil {
return 0, err
}
}
var complMultipartUpload completeMultipartUpload
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, _, err := optimalPartInfo(fileSize)
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(fileSize)
if err != nil {
return 0, err
}
// Part number always starts with '1'.
partNumber := 1
// Create a channel to communicate a part was uploaded.
// Buffer this to 10000, the maximum number of parts allowed by S3.
uploadedPartsCh := make(chan uploadedPartRes, 10000)
for partNumber <= totalPartsCount {
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(fileReader, totalUploadedSize, partSize)
// Create a channel to communicate which part to upload.
// Buffer this to 10000, the maximum number of parts allowed by S3.
uploadPartsCh := make(chan uploadPartReq, 10000)
// Add hash algorithms that need to be calculated by computeHash()
// In case of a non-v4 signature or https connection, sha256 is not needed.
hashAlgos := make(map[string]hash.Hash)
hashSums := make(map[string][]byte)
hashAlgos["md5"] = md5.New()
if c.signature.isV4() && !c.secure {
hashAlgos["sha256"] = sha256.New()
}
// Just for readability.
lastPartNumber := totalPartsCount
var prtSize int64
prtSize, err = computeHash(hashAlgos, hashSums, sectionReader)
if err != nil {
return 0, err
}
var reader io.Reader
// Update progress reader appropriately to the latest offset
// as we read from the source.
reader = newHook(sectionReader, progress)
// Verify if part should be uploaded.
if shouldUploadPart(objectPart{
ETag: hex.EncodeToString(hashSums["md5"]),
PartNumber: partNumber,
Size: prtSize,
}, partsInfo) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber,
hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
return totalUploadedSize, err
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
// Send each part through the partUploadCh to be uploaded.
for p := 1; p <= totalPartsCount; p++ {
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
// Update the progress reader for the skipped part.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, prtSize); err != nil {
return totalUploadedSize, err
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)
// Use three 'workers' to upload parts in parallel.
for w := 1; w <= 3; w++ {
go func() {
// Deal with each part as it comes through the channel.
for uploadReq := range uploadPartsCh {
// Add hash algorithms that need to be calculated by computeHash()
// In case of a non-v4 signature or https connection, sha256 is not needed.
hashAlgos := make(map[string]hash.Hash)
hashSums := make(map[string][]byte)
hashAlgos["md5"] = md5.New()
if c.signature.isV4() && !c.secure {
hashAlgos["sha256"] = sha256.New()
}
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if uploadReq.PartNum == lastPartNumber {
readOffset = (fileSize - lastPartSize)
missingPartSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(fileReader, readOffset, missingPartSize)
var prtSize int64
var err error
prtSize, err = computeHash(hashAlgos, hashSums, sectionReader)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Error: err,
}
// Exit the goroutine.
return
}
// Create the part to be uploaded.
verifyObjPart := objectPart{
ETag: hex.EncodeToString(hashSums["md5"]),
PartNumber: uploadReq.PartNum,
Size: partSize,
}
// If this is the last part do not give it the full part size.
if uploadReq.PartNum == lastPartNumber {
verifyObjPart.Size = lastPartSize
}
// Verify if part should be uploaded.
if shouldUploadPart(verifyObjPart, uploadReq) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, sectionReader, uploadReq.PartNum, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Error: err,
}
// Exit the goroutine.
return
}
// Save successfully uploaded part metadata.
uploadReq.Part = &objPart
}
// Return through the channel the part size.
uploadedPartsCh <- uploadedPartRes{
Size: verifyObjPart.Size,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
}()
}
// Retrieve each uploaded part once it is done.
for u := 1; u <= totalPartsCount; u++ {
uploadRes := <-uploadedPartsCh
if uploadRes.Error != nil {
return totalUploadedSize, uploadRes.Error
}
// Save successfully uploaded size.
totalUploadedSize += prtSize
// Increment part number.
partNumber++
// Retrieve each uploaded part and store it to be completed.
part := uploadRes.Part
if part == nil {
return totalUploadedSize, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the total uploaded size.
totalUploadedSize += uploadRes.Size
// Update the progress bar if there is one.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, uploadRes.Size); err != nil {
return totalUploadedSize, err
}
}
// Store the part to be completed.
complMultipartUpload.Parts = append(complMultipartUpload.Parts, completePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Verify if we uploaded all data.
@ -235,22 +296,9 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, fileSize, bucketName, objectName)
}
// Loop over uploaded parts to save them in a Parts array before completing the multipart request.
for _, part := range partsInfo {
var complPart completePart
complPart.ETag = part.ETag
complPart.PartNumber = part.PartNumber
completeMultipartUpload.Parts = append(completeMultipartUpload.Parts, complPart)
}
// Verify if totalPartsCount is not equal to total list of parts.
if totalPartsCount != len(completeMultipartUpload.Parts) {
return totalUploadedSize, ErrInvalidParts(partNumber, len(completeMultipartUpload.Parts))
}
// Sort all completed parts.
sort.Sort(completedParts(completeMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, completeMultipartUpload)
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)
if err != nil {
return totalUploadedSize, err
}

View File

@ -22,6 +22,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"fmt"
"hash"
"io"
"io/ioutil"
@ -44,11 +45,11 @@ import (
// If we exhaust all the known types, code proceeds to use stream as
// is where each part is re-downloaded, checksummed and verified
// before upload.
func (c Client) putObjectMultipart(bucketName, objectName string, reader io.Reader, size int64, contentType string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectMultipart(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
if size > 0 && size > minPartSize {
// Verify if reader is *os.File, then use file system functionalities.
if isFile(reader) {
return c.putObjectMultipartFromFile(bucketName, objectName, reader.(*os.File), size, contentType, progress)
return c.putObjectMultipartFromFile(bucketName, objectName, reader.(*os.File), size, metaData, progress)
}
// Verify if reader is *minio.Object or io.ReaderAt.
// NOTE: Verification of object is kept for a specific purpose
@ -57,17 +58,17 @@ func (c Client) putObjectMultipart(bucketName, objectName string, reader io.Read
// and such a functionality is used in the subsequent code
// path.
if isObject(reader) || isReadAt(reader) {
return c.putObjectMultipartFromReadAt(bucketName, objectName, reader.(io.ReaderAt), size, contentType, progress)
return c.putObjectMultipartFromReadAt(bucketName, objectName, reader.(io.ReaderAt), size, metaData, progress)
}
}
// For any other data size and reader type we do generic multipart
// approach by staging data in temporary files and uploading them.
return c.putObjectMultipartStream(bucketName, objectName, reader, size, contentType, progress)
return c.putObjectMultipartStream(bucketName, objectName, reader, size, metaData, progress)
}
// putObjectStream uploads files bigger than 5MiB, and also supports
// putObjectStream uploads files bigger than 64MiB, and also supports
// special case where size is unknown i.e '-1'.
func (c Client) putObjectMultipartStream(bucketName, objectName string, reader io.Reader, size int64, contentType string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectMultipartStream(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -82,26 +83,12 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// A map of all previously uploaded parts.
var partsInfo = make(map[int]objectPart)
// getUploadID for an object, initiates a new multipart request
// if it cannot find any previously partially uploaded object.
uploadID, isNew, err := c.getUploadID(bucketName, objectName, contentType)
// Get the upload id of a previously partially uploaded object or initiate a new multipart upload
uploadID, partsInfo, err := c.getMpartUploadSession(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
// If This session is a continuation of a previous session fetch all
// previously uploaded parts info.
if !isNew {
// Fetch previously uploaded parts and maximum part size.
partsInfo, err = c.listObjectParts(bucketName, objectName, uploadID)
if err != nil {
return 0, err
}
}
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, _, err := optimalPartInfo(size)
if err != nil {
@ -115,7 +102,6 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
tmpBuffer := new(bytes.Buffer)
for partNumber <= totalPartsCount {
// Choose hash algorithms to be calculated by hashCopyN, avoid sha256
// with non-v4 signature request or HTTPS connection
hashSums := make(map[string][]byte)
@ -138,12 +124,14 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// as we read from the source.
reader = newHook(tmpBuffer, progress)
part, ok := partsInfo[partNumber]
// Verify if part should be uploaded.
if shouldUploadPart(objectPart{
if !ok || shouldUploadPart(objectPart{
ETag: hex.EncodeToString(hashSums["md5"]),
PartNumber: partNumber,
Size: prtSize,
}, partsInfo) {
}, uploadPartReq{PartNum: partNumber, Part: &part}) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber, hashSums["md5"], hashSums["sha256"], prtSize)
@ -169,14 +157,14 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// Save successfully uploaded size.
totalUploadedSize += prtSize
// Increment part number.
partNumber++
// For unknown size, Read EOF we break away.
// We do not have to upload till totalPartsCount.
if size < 0 && rErr == io.EOF {
break
}
// Increment part number.
partNumber++
}
// Verify if we uploaded all the data.
@ -186,19 +174,17 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
}
}
// Loop over uploaded parts to save them in a Parts array before completing the multipart request.
for _, part := range partsInfo {
var complPart completePart
complPart.ETag = part.ETag
complPart.PartNumber = part.PartNumber
complMultipartUpload.Parts = append(complMultipartUpload.Parts, complPart)
}
if size > 0 {
// Verify if totalPartsCount is not equal to total list of parts.
if totalPartsCount != len(complMultipartUpload.Parts) {
return totalUploadedSize, ErrInvalidParts(partNumber, len(complMultipartUpload.Parts))
// Loop over total uploaded parts to save them in
// Parts array before completing the multipart request.
for i := 1; i < partNumber; i++ {
part, ok := partsInfo[i]
if !ok {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", i))
}
complMultipartUpload.Parts = append(complMultipartUpload.Parts, completePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Sort all completed parts.
@ -213,7 +199,7 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
}
// initiateMultipartUpload - Initiates a multipart upload and returns an upload ID.
func (c Client) initiateMultipartUpload(bucketName, objectName, contentType string) (initiateMultipartUploadResult, error) {
func (c Client) initiateMultipartUpload(bucketName, objectName string, metaData map[string][]string) (initiateMultipartUploadResult, error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return initiateMultipartUploadResult{}, err
@ -226,13 +212,18 @@ func (c Client) initiateMultipartUpload(bucketName, objectName, contentType stri
urlValues := make(url.Values)
urlValues.Set("uploads", "")
if contentType == "" {
contentType = "application/octet-stream"
}
// Set ContentType header.
customHeader := make(http.Header)
customHeader.Set("Content-Type", contentType)
for k, v := range metaData {
if len(v) > 0 {
customHeader.Set(k, v[0])
}
}
// Set a default content-type header if the latter is not provided
if v, ok := metaData["Content-Type"]; !ok || len(v) == 0 {
customHeader.Set("Content-Type", "application/octet-stream")
}
reqMetadata := requestMetadata{
bucketName: bucketName,

View File

@ -16,10 +16,22 @@
package minio
import "io"
import (
"io"
"strings"
// PutObjectWithProgress - With progress.
"github.com/minio/minio-go/pkg/s3utils"
)
// PutObjectWithProgress - with progress.
func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.Reader, contentType string, progress io.Reader) (n int64, err error) {
metaData := make(map[string][]string)
metaData["Content-Type"] = []string{contentType}
return c.PutObjectWithMetadata(bucketName, objectName, reader, metaData, progress)
}
// PutObjectWithMetadata - with metadata.
func (c Client) PutObjectWithMetadata(bucketName, objectName string, reader io.Reader, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -47,7 +59,7 @@ func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.R
// NOTE: Google Cloud Storage does not implement Amazon S3 Compatible multipart PUT.
// So we fall back to single PUT operation with the maximum limit of 5GiB.
if isGoogleEndpoint(c.endpointURL) {
if s3utils.IsGoogleEndpoint(c.endpointURL) {
if size <= -1 {
return 0, ErrorResponse{
Code: "NotImplemented",
@ -60,11 +72,11 @@ func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.R
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// Do not compute MD5 for Google Cloud Storage. Uploads up to 5GiB in size.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, contentType, progress)
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metaData, progress)
}
// NOTE: S3 doesn't allow anonymous multipart requests.
if isAmazonEndpoint(c.endpointURL) && c.anonymous {
if s3utils.IsAmazonEndpoint(c.endpointURL) && c.anonymous {
if size <= -1 {
return 0, ErrorResponse{
Code: "NotImplemented",
@ -78,26 +90,26 @@ func (c Client) PutObjectWithProgress(bucketName, objectName string, reader io.R
}
// Do not compute MD5 for anonymous requests to Amazon
// S3. Uploads up to 5GiB in size.
return c.putObjectNoChecksum(bucketName, objectName, reader, size, contentType, progress)
return c.putObjectNoChecksum(bucketName, objectName, reader, size, metaData, progress)
}
// putSmall object.
if size < minPartSize && size >= 0 {
return c.putObjectSingle(bucketName, objectName, reader, size, contentType, progress)
return c.putObjectSingle(bucketName, objectName, reader, size, metaData, progress)
}
// For all sizes greater than 5MiB do multipart.
n, err = c.putObjectMultipart(bucketName, objectName, reader, size, contentType, progress)
n, err = c.putObjectMultipart(bucketName, objectName, reader, size, metaData, progress)
if err != nil {
errResp := ToErrorResponse(err)
// Verify if multipart functionality is not available, if not
// fall back to single PutObject operation.
if errResp.Code == "AccessDenied" && errResp.Message == "Access Denied." {
if errResp.Code == "AccessDenied" && strings.Contains(errResp.Message, "Access Denied") {
// Verify if size of reader is greater than '5GiB'.
if size > maxSinglePutObjectSize {
return 0, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
// Fall back to uploading as single PutObject operation.
return c.putObjectSingle(bucketName, objectName, reader, size, contentType, progress)
return c.putObjectSingle(bucketName, objectName, reader, size, metaData, progress)
}
return n, err
}

View File

@ -20,21 +20,34 @@ import (
"bytes"
"crypto/md5"
"crypto/sha256"
"fmt"
"hash"
"io"
"io/ioutil"
"sort"
)
// uploadedPartRes - the response received from a part upload.
type uploadedPartRes struct {
Error error // Any error encountered while uploading the part.
PartNum int // Number of the part uploaded.
Size int64 // Size of the part uploaded.
Part *objectPart
}
type uploadPartReq struct {
PartNum int // Number of the part uploaded.
Part *objectPart // Size of the part uploaded.
}
// shouldUploadPartReadAt - verify if part should be uploaded.
func shouldUploadPartReadAt(objPart objectPart, objectParts map[int]objectPart) bool {
func shouldUploadPartReadAt(objPart objectPart, uploadReq uploadPartReq) bool {
// If part not found part should be uploaded.
uploadedPart, found := objectParts[objPart.PartNumber]
if !found {
if uploadReq.Part == nil {
return true
}
// if size mismatches part should be uploaded.
if uploadedPart.Size != objPart.Size {
if uploadReq.Part.Size != objPart.Size {
return true
}
return false
@ -50,7 +63,7 @@ func shouldUploadPartReadAt(objPart objectPart, objectParts map[int]objectPart)
// temporary files for staging all the data, these temporary files are
// cleaned automatically when the caller i.e http client closes the
// stream after uploading all the contents successfully.
func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, reader io.ReaderAt, size int64, contentType string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, reader io.ReaderAt, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -59,9 +72,8 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
return 0, err
}
// Get upload id for an object, initiates a new multipart request
// if it cannot find any previously partially uploaded object.
uploadID, isNew, err := c.getUploadID(bucketName, objectName, contentType)
// Get the upload id of a previously partially uploaded object or initiate a new multipart upload
uploadID, partsInfo, err := c.getMpartUploadSession(bucketName, objectName, metaData)
if err != nil {
return 0, err
}
@ -72,127 +84,150 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
// Complete multipart upload.
var complMultipartUpload completeMultipartUpload
// A map of all uploaded parts.
var partsInfo = make(map[int]objectPart)
// Fetch all parts info previously uploaded.
if !isNew {
partsInfo, err = c.listObjectParts(bucketName, objectName, uploadID)
if err != nil {
return 0, err
}
}
// Calculate the optimal parts info for a given size.
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(size)
if err != nil {
return 0, err
}
// Used for readability, lastPartNumber is always
// totalPartsCount.
// Used for readability, lastPartNumber is always totalPartsCount.
lastPartNumber := totalPartsCount
// partNumber always starts with '1'.
partNumber := 1
// Declare a channel that sends the next part number to be uploaded.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadPartsCh := make(chan uploadPartReq, 10000)
// Initialize a temporary buffer.
tmpBuffer := new(bytes.Buffer)
// Declare a channel that sends back the response of a part upload.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadedPartsCh := make(chan uploadedPartRes, 10000)
// Read defaults to reading at 5MiB buffer.
readAtBuffer := make([]byte, optimalReadBufferSize)
// Upload all the missing parts.
for partNumber <= lastPartNumber {
// Verify object if its uploaded.
verifyObjPart := objectPart{
PartNumber: partNumber,
Size: partSize,
}
// Special case if we see a last part number, save last part
// size as the proper part size.
if partNumber == lastPartNumber {
verifyObjPart = objectPart{
PartNumber: lastPartNumber,
Size: lastPartSize,
}
// Send each part number to the channel to be processed.
for p := 1; p <= totalPartsCount; p++ {
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)
// Verify if part should be uploaded.
if !shouldUploadPartReadAt(verifyObjPart, partsInfo) {
// Increment part number when not uploaded.
partNumber++
if progress != nil {
// Update the progress reader for the skipped part.
if _, err = io.CopyN(ioutil.Discard, progress, verifyObjPart.Size); err != nil {
return 0, err
// Receive each part number from the channel allowing three parallel uploads.
for w := 1; w <= 3; w++ {
go func() {
// Read defaults to reading at 5MiB buffer.
readAtBuffer := make([]byte, optimalReadBufferSize)
// Each worker will draw from the part channel and upload in parallel.
for uploadReq := range uploadPartsCh {
// Declare a new tmpBuffer.
tmpBuffer := new(bytes.Buffer)
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if uploadReq.PartNum == lastPartNumber {
readOffset = (size - lastPartSize)
missingPartSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(reader, readOffset, missingPartSize)
// Choose the needed hash algorithms to be calculated by hashCopyBuffer.
// Sha256 is avoided in non-v4 signature requests or HTTPS connections
hashSums := make(map[string][]byte)
hashAlgos := make(map[string]hash.Hash)
hashAlgos["md5"] = md5.New()
if c.signature.isV4() && !c.secure {
hashAlgos["sha256"] = sha256.New()
}
var prtSize int64
var err error
prtSize, err = hashCopyBuffer(hashAlgos, hashSums, tmpBuffer, sectionReader, readAtBuffer)
if err != nil {
// Send the error back through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Error: err,
}
// Exit the goroutine.
return
}
// Verify object if its uploaded.
verifyObjPart := objectPart{
PartNumber: uploadReq.PartNum,
Size: partSize,
}
// Special case if we see a last part number, save last part
// size as the proper part size.
if uploadReq.PartNum == lastPartNumber {
verifyObjPart.Size = lastPartSize
}
// Only upload the necessary parts. Otherwise return size through channel
// to update any progress bar.
if shouldUploadPartReadAt(verifyObjPart, uploadReq) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, tmpBuffer, uploadReq.PartNum, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Error: err,
}
// Exit the goroutine.
return
}
// Save successfully uploaded part metadata.
uploadReq.Part = &objPart
}
// Send successful part info through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: verifyObjPart.Size,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
continue
}
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(partNumber-1) * partSize
missingPartSize := partSize
// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if partNumber == lastPartNumber {
readOffset = (size - lastPartSize)
missingPartSize = lastPartSize
}
// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(reader, readOffset, missingPartSize)
// Choose the needed hash algorithms to be calculated by hashCopyBuffer.
// Sha256 is avoided in non-v4 signature requests or HTTPS connections
hashSums := make(map[string][]byte)
hashAlgos := make(map[string]hash.Hash)
hashAlgos["md5"] = md5.New()
if c.signature.isV4() && !c.secure {
hashAlgos["sha256"] = sha256.New()
}
var prtSize int64
prtSize, err = hashCopyBuffer(hashAlgos, hashSums, tmpBuffer, sectionReader, readAtBuffer)
if err != nil {
return 0, err
}
var reader io.Reader
// Update progress reader appropriately to the latest offset
// as we read from the source.
reader = newHook(tmpBuffer, progress)
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
// Reset the buffer upon any error.
tmpBuffer.Reset()
return 0, err
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
// Increment part number here after successful part upload.
partNumber++
// Reset the buffer.
tmpBuffer.Reset()
}()
}
// Loop over uploaded parts to save them in a Parts array before completing the multipart request.
for _, part := range partsInfo {
var complPart completePart
complPart.ETag = part.ETag
complPart.PartNumber = part.PartNumber
totalUploadedSize += part.Size
complMultipartUpload.Parts = append(complMultipartUpload.Parts, complPart)
// Gather the responses as they occur and update any
// progress bar.
for u := 1; u <= totalPartsCount; u++ {
uploadRes := <-uploadedPartsCh
if uploadRes.Error != nil {
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
// part, ok := partsInfo[uploadRes.PartNum]
part := uploadRes.Part
if part == nil {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the totalUploadedSize.
totalUploadedSize += uploadRes.Size
// Update the progress bar if there is one.
if progress != nil {
if _, err = io.CopyN(ioutil.Discard, progress, uploadRes.Size); err != nil {
return totalUploadedSize, err
}
}
// Store the parts to be completed in order.
complMultipartUpload.Parts = append(complMultipartUpload.Parts, completePart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
// Verify if we uploaded all the data.
@ -200,11 +235,6 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
return totalUploadedSize, ErrUnexpectedEOF(totalUploadedSize, size, bucketName, objectName)
}
// Verify if totalPartsCount is not equal to total list of parts.
if totalPartsCount != len(complMultipartUpload.Parts) {
return totalUploadedSize, ErrInvalidParts(totalPartsCount, len(complMultipartUpload.Parts))
}
// Sort all completed parts.
sort.Sort(completedParts(complMultipartUpload.Parts))
_, err = c.completeMultipartUpload(bucketName, objectName, uploadID, complMultipartUpload)

View File

@ -103,11 +103,10 @@ func getReaderSize(reader io.Reader) (size int64, err error) {
// implement Seekable calls. Ignore them and treat
// them like a stream with unknown length.
switch st.Name() {
case "stdin":
fallthrough
case "stdout":
fallthrough
case "stderr":
case "stdin", "stdout", "stderr":
return
// Ignore read/write stream of os.Pipe() which have unknown length too.
case "|0", "|1":
return
}
size = st.Size()
@ -151,7 +150,7 @@ func (c Client) PutObject(bucketName, objectName string, reader io.Reader, conte
// putObjectNoChecksum special function used Google Cloud Storage. This special function
// is used for Google Cloud Storage since Google's multipart API is not S3 compatible.
func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Reader, size int64, contentType string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -169,7 +168,7 @@ func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Rea
// This function does not calculate sha256 and md5sum for payload.
// Execute put object.
st, err := c.putObjectDo(bucketName, objectName, readSeeker, nil, nil, size, contentType)
st, err := c.putObjectDo(bucketName, objectName, readSeeker, nil, nil, size, metaData)
if err != nil {
return 0, err
}
@ -181,7 +180,7 @@ func (c Client) putObjectNoChecksum(bucketName, objectName string, reader io.Rea
// putObjectSingle is a special function for uploading single put object request.
// This special function is used as a fallback when multipart upload fails.
func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader, size int64, contentType string, progress io.Reader) (n int64, err error) {
func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader, size int64, metaData map[string][]string, progress io.Reader) (n int64, err error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return 0, err
@ -221,6 +220,9 @@ func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader,
}
defer tmpFile.Close()
size, err = hashCopyN(hashAlgos, hashSums, tmpFile, reader, size)
if err != nil {
return 0, err
}
// Seek back to beginning of the temporary file.
if _, err = tmpFile.Seek(0, 0); err != nil {
return 0, err
@ -234,7 +236,7 @@ func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader,
}
}
// Execute put object.
st, err := c.putObjectDo(bucketName, objectName, reader, hashSums["md5"], hashSums["sha256"], size, contentType)
st, err := c.putObjectDo(bucketName, objectName, reader, hashSums["md5"], hashSums["sha256"], size, metaData)
if err != nil {
return 0, err
}
@ -252,7 +254,7 @@ func (c Client) putObjectSingle(bucketName, objectName string, reader io.Reader,
// putObjectDo - executes the put object http operation.
// NOTE: You must have WRITE permissions on a bucket to add an object to it.
func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5Sum []byte, sha256Sum []byte, size int64, contentType string) (ObjectInfo, error) {
func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5Sum []byte, sha256Sum []byte, size int64, metaData map[string][]string) (ObjectInfo, error) {
// Input validation.
if err := isValidBucketName(bucketName); err != nil {
return ObjectInfo{}, err
@ -269,13 +271,20 @@ func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5
return ObjectInfo{}, ErrEntityTooLarge(size, maxSinglePutObjectSize, bucketName, objectName)
}
if strings.TrimSpace(contentType) == "" {
contentType = "application/octet-stream"
}
// Set headers.
customHeader := make(http.Header)
customHeader.Set("Content-Type", contentType)
// Set metadata to headers
for k, v := range metaData {
if len(v) > 0 {
customHeader.Set(k, v[0])
}
}
// If Content-Type is not provided, set the default application/octet-stream one
if v, ok := metaData["Content-Type"]; !ok || len(v) == 0 {
customHeader.Set("Content-Type", "application/octet-stream")
}
// Populate request metadata.
reqMetadata := requestMetadata{
@ -300,13 +309,13 @@ func (c Client) putObjectDo(bucketName, objectName string, reader io.Reader, md5
}
}
var metadata ObjectInfo
var objInfo ObjectInfo
// Trim off the odd double quotes from ETag in the beginning and end.
metadata.ETag = strings.TrimPrefix(resp.Header.Get("ETag"), "\"")
metadata.ETag = strings.TrimSuffix(metadata.ETag, "\"")
objInfo.ETag = strings.TrimPrefix(resp.Header.Get("ETag"), "\"")
objInfo.ETag = strings.TrimSuffix(objInfo.ETag, "\"")
// A success here means data was written to server successfully.
metadata.Size = size
objInfo.Size = size
// Return here.
return metadata, nil
return objInfo, nil
}

View File

@ -17,6 +17,9 @@
package minio
import (
"bytes"
"encoding/xml"
"io"
"net/http"
"net/url"
)
@ -68,12 +71,142 @@ func (c Client) RemoveObject(bucketName, objectName string) error {
if err != nil {
return err
}
if resp != nil {
// if some unexpected error happened and max retry is reached, we want to let client know
if resp.StatusCode != http.StatusNoContent {
return httpRespToErrorResponse(resp, bucketName, objectName)
}
}
// DeleteObject always responds with http '204' even for
// objects which do not exist. So no need to handle them
// specifically.
return nil
}
// RemoveObjectError - container of Multi Delete S3 API error
type RemoveObjectError struct {
ObjectName string
Err error
}
// generateRemoveMultiObjects - generate the XML request for remove multi objects request
func generateRemoveMultiObjectsRequest(objects []string) []byte {
rmObjects := []deleteObject{}
for _, obj := range objects {
rmObjects = append(rmObjects, deleteObject{Key: obj})
}
xmlBytes, _ := xml.Marshal(deleteMultiObjects{Objects: rmObjects, Quiet: true})
return xmlBytes
}
// processRemoveMultiObjectsResponse - parse the remove multi objects web service
// and return the success/failure result status for each object
func processRemoveMultiObjectsResponse(body io.Reader, objects []string, errorCh chan<- RemoveObjectError) {
// Parse multi delete XML response
rmResult := &deleteMultiObjectsResult{}
err := xmlDecoder(body, rmResult)
if err != nil {
errorCh <- RemoveObjectError{ObjectName: "", Err: err}
return
}
// Fill deletion that returned an error.
for _, obj := range rmResult.UnDeletedObjects {
errorCh <- RemoveObjectError{
ObjectName: obj.Key,
Err: ErrorResponse{
Code: obj.Code,
Message: obj.Message,
},
}
}
}
// RemoveObjects remove multiples objects from a bucket.
// The list of objects to remove are received from objectsCh.
// Remove failures are sent back via error channel.
func (c Client) RemoveObjects(bucketName string, objectsCh <-chan string) <-chan RemoveObjectError {
errorCh := make(chan RemoveObjectError, 1)
// Validate if bucket name is valid.
if err := isValidBucketName(bucketName); err != nil {
defer close(errorCh)
errorCh <- RemoveObjectError{
Err: err,
}
return errorCh
}
// Validate objects channel to be properly allocated.
if objectsCh == nil {
defer close(errorCh)
errorCh <- RemoveObjectError{
Err: ErrInvalidArgument("Objects channel cannot be nil"),
}
return errorCh
}
// Generate and call MultiDelete S3 requests based on entries received from objectsCh
go func(errorCh chan<- RemoveObjectError) {
maxEntries := 1000
finish := false
urlValues := make(url.Values)
urlValues.Set("delete", "")
// Close error channel when Multi delete finishes.
defer close(errorCh)
// Loop over entries by 1000 and call MultiDelete requests
for {
if finish {
break
}
count := 0
var batch []string
// Try to gather 1000 entries
for object := range objectsCh {
batch = append(batch, object)
if count++; count >= maxEntries {
break
}
}
if count == 0 {
// Multi Objects Delete API doesn't accept empty object list, quit immediatly
break
}
if count < maxEntries {
// We didn't have 1000 entries, so this is the last batch
finish = true
}
// Generate remove multi objects XML request
removeBytes := generateRemoveMultiObjectsRequest(batch)
// Execute GET on bucket to list objects.
resp, err := c.executeMethod("POST", requestMetadata{
bucketName: bucketName,
queryValues: urlValues,
contentBody: bytes.NewReader(removeBytes),
contentLength: int64(len(removeBytes)),
contentMD5Bytes: sumMD5(removeBytes),
contentSHA256Bytes: sum256(removeBytes),
})
if err != nil {
for _, b := range batch {
errorCh <- RemoveObjectError{ObjectName: b, Err: err}
}
continue
}
// Process multiobjects remove xml response
processRemoveMultiObjectsResponse(resp.Body, batch, errorCh)
closeResponse(resp)
}
}(errorCh)
return errorCh
}
// RemoveIncompleteUpload aborts an partially uploaded object.
// Requires explicit authentication, no anonymous requests are allowed for multipart API.
func (c Client) RemoveIncompleteUpload(bucketName, objectName string) error {

View File

@ -206,3 +206,39 @@ type createBucketConfiguration struct {
XMLName xml.Name `xml:"http://s3.amazonaws.com/doc/2006-03-01/ CreateBucketConfiguration" json:"-"`
Location string `xml:"LocationConstraint"`
}
// deleteObject container for Delete element in MultiObjects Delete XML request
type deleteObject struct {
Key string
VersionID string `xml:"VersionId,omitempty"`
}
// deletedObject container for Deleted element in MultiObjects Delete XML response
type deletedObject struct {
Key string
VersionID string `xml:"VersionId,omitempty"`
// These fields are ignored.
DeleteMarker bool
DeleteMarkerVersionID string
}
// nonDeletedObject container for Error element (failed deletion) in MultiObjects Delete XML response
type nonDeletedObject struct {
Key string
Code string
Message string
}
// deletedMultiObjects container for MultiObjects Delete XML request
type deleteMultiObjects struct {
XMLName xml.Name `xml:"Delete"`
Quiet bool
Objects []deleteObject `xml:"Object"`
}
// deletedMultiObjectsResult container for MultiObjects Delete XML response
type deleteMultiObjectsResult struct {
XMLName xml.Name `xml:"DeleteResult"`
DeletedObjects []deletedObject `xml:"Deleted"`
UnDeletedObjects []nonDeletedObject `xml:"Error"`
}

View File

@ -21,6 +21,8 @@ import (
"strconv"
"strings"
"time"
"github.com/minio/minio-go/pkg/s3utils"
)
// BucketExists verify if bucket exists and you have permission to access it.
@ -49,6 +51,31 @@ func (c Client) BucketExists(bucketName string) (bool, error) {
return true, nil
}
// List of header keys to be filtered, usually
// from all S3 API http responses.
var defaultFilterKeys = []string{
"Transfer-Encoding",
"Accept-Ranges",
"Date",
"Server",
"Vary",
"x-amz-request-id",
"x-amz-id-2",
// Add new headers to be ignored.
}
// Extract only necessary metadata header key/values by
// filtering them out with a list of custom header keys.
func extractObjMetadata(header http.Header) http.Header {
filterKeys := append([]string{
"ETag",
"Content-Length",
"Last-Modified",
"Content-Type",
}, defaultFilterKeys...)
return filterHeader(header, filterKeys)
}
// StatObject verifies if object exists and you have permission to access.
func (c Client) StatObject(bucketName, objectName string) (ObjectInfo, error) {
// Input validation.
@ -78,17 +105,21 @@ func (c Client) StatObject(bucketName, objectName string) (ObjectInfo, error) {
md5sum := strings.TrimPrefix(resp.Header.Get("ETag"), "\"")
md5sum = strings.TrimSuffix(md5sum, "\"")
// Parse content length.
size, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil {
return ObjectInfo{}, ErrorResponse{
Code: "InternalError",
Message: "Content-Length is invalid. " + reportIssue,
BucketName: bucketName,
Key: objectName,
RequestID: resp.Header.Get("x-amz-request-id"),
HostID: resp.Header.Get("x-amz-id-2"),
Region: resp.Header.Get("x-amz-bucket-region"),
// Content-Length is not valid for Google Cloud Storage, do not verify.
var size int64 = -1
if !s3utils.IsGoogleEndpoint(c.endpointURL) {
// Parse content length.
size, err = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil {
return ObjectInfo{}, ErrorResponse{
Code: "InternalError",
Message: "Content-Length is invalid. " + reportIssue,
BucketName: bucketName,
Key: objectName,
RequestID: resp.Header.Get("x-amz-request-id"),
HostID: resp.Header.Get("x-amz-id-2"),
Region: resp.Header.Get("x-amz-bucket-region"),
}
}
}
// Parse Last-Modified has http time format.
@ -109,12 +140,19 @@ func (c Client) StatObject(bucketName, objectName string) (ObjectInfo, error) {
if contentType == "" {
contentType = "application/octet-stream"
}
// Extract only the relevant header keys describing the object.
// following function filters out a list of standard set of keys
// which are not part of object metadata.
metadata := extractObjMetadata(resp.Header)
// Save object metadata info.
var objectStat ObjectInfo
objectStat.ETag = md5sum
objectStat.Key = objectName
objectStat.Size = size
objectStat.LastModified = date
objectStat.ContentType = contentType
return objectStat, nil
return ObjectInfo{
ETag: md5sum,
Key: objectName,
Size: size,
LastModified: date,
ContentType: contentType,
Metadata: metadata,
}, nil
}

View File

@ -33,12 +33,18 @@ import (
"strings"
"sync"
"time"
"github.com/minio/minio-go/pkg/s3signer"
"github.com/minio/minio-go/pkg/s3utils"
)
// Client implements Amazon S3 compatible methods.
type Client struct {
/// Standard options.
// Parsed endpoint url provided by the user.
endpointURL url.URL
// AccessKeyID required for authorized requests.
accessKeyID string
// SecretAccessKey required for authorized requests.
@ -53,7 +59,6 @@ type Client struct {
appName string
appVersion string
}
endpointURL string
// Indicate whether we are using https or not
secure bool
@ -66,6 +71,9 @@ type Client struct {
isTraceEnabled bool
traceOutput io.Writer
// S3 specific accelerated endpoint.
s3AccelerateEndpoint string
// Random seed.
random *rand.Rand
}
@ -73,7 +81,7 @@ type Client struct {
// Global constants.
const (
libraryName = "minio-go"
libraryVersion = "2.0.1"
libraryVersion = "2.0.4"
)
// User Agent should always following the below style.
@ -116,13 +124,12 @@ func New(endpoint string, accessKeyID, secretAccessKey string, secure bool) (*Cl
if err != nil {
return nil, err
}
// Google cloud storage should be set to signature V2, force it if
// not.
if isGoogleEndpoint(clnt.endpointURL) {
// Google cloud storage should be set to signature V2, force it if not.
if s3utils.IsGoogleEndpoint(clnt.endpointURL) {
clnt.signature = SignatureV2
}
// If Amazon S3 set to signature v2.n
if isAmazonEndpoint(clnt.endpointURL) {
if s3utils.IsAmazonEndpoint(clnt.endpointURL) {
clnt.signature = SignatureV4
}
return clnt, nil
@ -151,6 +158,18 @@ func (r *lockedRandSource) Seed(seed int64) {
r.lk.Unlock()
}
// redirectHeaders copies all headers when following a redirect URL.
// This won't be needed anymore from go 1.8 (https://github.com/golang/go/issues/4800)
func redirectHeaders(req *http.Request, via []*http.Request) error {
if len(via) == 0 {
return nil
}
for key, val := range via[0].Header {
req.Header[key] = val
}
return nil
}
func privateNew(endpoint, accessKeyID, secretAccessKey string, secure bool) (*Client, error) {
// construct endpoint.
endpointURL, err := getEndpointURL(endpoint, secure)
@ -170,11 +189,12 @@ func privateNew(endpoint, accessKeyID, secretAccessKey string, secure bool) (*Cl
clnt.secure = secure
// Save endpoint URL, user agent for future uses.
clnt.endpointURL = endpointURL.String()
clnt.endpointURL = *endpointURL
// Instantiate http client and bucket location cache.
clnt.httpClient = &http.Client{
Transport: http.DefaultTransport,
Transport: http.DefaultTransport,
CheckRedirect: redirectHeaders,
}
// Instantiae bucket location cache.
@ -189,8 +209,7 @@ func privateNew(endpoint, accessKeyID, secretAccessKey string, secure bool) (*Cl
// SetAppInfo - add application details to user agent.
func (c *Client) SetAppInfo(appName string, appVersion string) {
// if app name and version is not set, we do not a new user
// agent.
// if app name and version not set, we do not set a new user agent.
if appName != "" && appVersion != "" {
c.appInfo = struct {
appName string
@ -241,8 +260,18 @@ func (c *Client) TraceOff() {
c.isTraceEnabled = false
}
// requestMetadata - is container for all the values to make a
// request.
// SetS3TransferAccelerate - turns s3 accelerated endpoint on or off for all your
// requests. This feature is only specific to S3 for all other endpoints this
// function does nothing. To read further details on s3 transfer acceleration
// please vist -
// http://docs.aws.amazon.com/AmazonS3/latest/dev/transfer-acceleration.html
func (c *Client) SetS3TransferAccelerate(accelerateEndpoint string) {
if s3utils.IsAmazonEndpoint(c.endpointURL) {
c.s3AccelerateEndpoint = accelerateEndpoint
}
}
// requestMetadata - is container for all the values to make a request.
type requestMetadata struct {
// If set newRequest presigns the URL.
presignURL bool
@ -262,6 +291,12 @@ type requestMetadata struct {
contentMD5Bytes []byte
}
// regCred matches credential string in HTTP header
var regCred = regexp.MustCompile("Credential=([A-Z0-9]+)/")
// regCred matches signature string in HTTP header
var regSign = regexp.MustCompile("Signature=([[0-9a-f]+)")
// Filter out signature value from Authorization header.
func (c Client) filterSignature(req *http.Request) {
// For anonymous requests, no need to filter.
@ -281,11 +316,9 @@ func (c Client) filterSignature(req *http.Request) {
origAuth := req.Header.Get("Authorization")
// Strip out accessKeyID from:
// Credential=<access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request
regCred := regexp.MustCompile("Credential=([A-Z0-9]+)/")
newAuth := regCred.ReplaceAllString(origAuth, "Credential=**REDACTED**/")
// Strip out 256-bit signature from: Signature=<256-bit signature>
regSign := regexp.MustCompile("Signature=([[0-9a-f]+)")
newAuth = regSign.ReplaceAllString(newAuth, "Signature=**REDACTED**")
// Set a temporary redacted auth
@ -364,20 +397,35 @@ func (c Client) dumpHTTP(req *http.Request, resp *http.Response) error {
// do - execute http request.
func (c Client) do(req *http.Request) (*http.Response, error) {
// do the request.
resp, err := c.httpClient.Do(req)
if err != nil {
// Handle this specifically for now until future Golang
// versions fix this issue properly.
urlErr, ok := err.(*url.Error)
if ok && strings.Contains(urlErr.Err.Error(), "EOF") {
return nil, &url.Error{
Op: urlErr.Op,
URL: urlErr.URL,
Err: fmt.Errorf("Connection closed by foreign host %s. Retry again.", urlErr.URL),
var resp *http.Response
var err error
// Do the request in a loop in case of 307 http is met since golang still doesn't
// handle properly this situation (https://github.com/golang/go/issues/7912)
for {
resp, err = c.httpClient.Do(req)
if err != nil {
// Handle this specifically for now until future Golang
// versions fix this issue properly.
urlErr, ok := err.(*url.Error)
if ok && strings.Contains(urlErr.Err.Error(), "EOF") {
return nil, &url.Error{
Op: urlErr.Op,
URL: urlErr.URL,
Err: fmt.Errorf("Connection closed by foreign host %s. Retry again.", urlErr.URL),
}
}
return nil, err
}
// Redo the request with the new redirect url if http 307 is returned, quit the loop otherwise
if resp != nil && resp.StatusCode == http.StatusTemporaryRedirect {
newURL, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
break
}
req.URL = newURL
} else {
break
}
return nil, err
}
// Response cannot be non-nil, report if its the case.
@ -467,6 +515,8 @@ func (c Client) executeMethod(method string, metadata requestMetadata) (res *htt
// Read the body to be saved later.
errBodyBytes, err := ioutil.ReadAll(res.Body)
// res.Body should be closed
closeResponse(res)
if err != nil {
return nil, err
}
@ -512,7 +562,7 @@ func (c Client) newRequest(method string, metadata requestMetadata) (req *http.R
// Default all requests to "us-east-1" or "cn-north-1" (china region)
location := "us-east-1"
if isAmazonChinaEndpoint(c.endpointURL) {
if s3utils.IsAmazonChinaEndpoint(c.endpointURL) {
// For china specifically we need to set everything to
// cn-north-1 for now, there is no easier way until AWS S3
// provides a cleaner compatible API across "us-east-1" and
@ -550,10 +600,10 @@ func (c Client) newRequest(method string, metadata requestMetadata) (req *http.R
}
if c.signature.isV2() {
// Presign URL with signature v2.
req = preSignV2(*req, c.accessKeyID, c.secretAccessKey, metadata.expires)
req = s3signer.PreSignV2(*req, c.accessKeyID, c.secretAccessKey, metadata.expires)
} else {
// Presign URL with signature v4.
req = preSignV4(*req, c.accessKeyID, c.secretAccessKey, location, metadata.expires)
req = s3signer.PreSignV4(*req, c.accessKeyID, c.secretAccessKey, location, metadata.expires)
}
return req, nil
}
@ -563,10 +613,10 @@ func (c Client) newRequest(method string, metadata requestMetadata) (req *http.R
req.Body = ioutil.NopCloser(metadata.contentBody)
}
// FIXEM: Enable this when Google Cloud Storage properly supports 100-continue.
// FIXME: Enable this when Google Cloud Storage properly supports 100-continue.
// Skip setting 'expect' header for Google Cloud Storage, there
// are some known issues - https://github.com/restic/restic/issues/520
if !isGoogleEndpoint(c.endpointURL) {
if !s3utils.IsGoogleEndpoint(c.endpointURL) && c.s3AccelerateEndpoint == "" {
// Set 'Expect' header for the request.
req.Header.Set("Expect", "100-continue")
}
@ -610,10 +660,10 @@ func (c Client) newRequest(method string, metadata requestMetadata) (req *http.R
if !c.anonymous {
if c.signature.isV2() {
// Add signature version '2' authorization header.
req = signV2(*req, c.accessKeyID, c.secretAccessKey)
req = s3signer.SignV2(*req, c.accessKeyID, c.secretAccessKey)
} else if c.signature.isV4() {
// Add signature version '4' authorization header.
req = signV4(*req, c.accessKeyID, c.secretAccessKey, location)
req = s3signer.SignV4(*req, c.accessKeyID, c.secretAccessKey, location)
}
}
@ -631,26 +681,34 @@ func (c Client) setUserAgent(req *http.Request) {
// makeTargetURL make a new target url.
func (c Client) makeTargetURL(bucketName, objectName, bucketLocation string, queryValues url.Values) (*url.URL, error) {
// Save host.
url, err := url.Parse(c.endpointURL)
if err != nil {
return nil, err
}
host := url.Host
host := c.endpointURL.Host
// For Amazon S3 endpoint, try to fetch location based endpoint.
if isAmazonEndpoint(c.endpointURL) {
// Fetch new host based on the bucket location.
host = getS3Endpoint(bucketLocation)
if s3utils.IsAmazonEndpoint(c.endpointURL) {
if c.s3AccelerateEndpoint != "" && bucketName != "" {
// http://docs.aws.amazon.com/AmazonS3/latest/dev/transfer-acceleration.html
// Disable transfer acceleration for non-compliant bucket names.
if strings.Contains(bucketName, ".") {
return nil, ErrTransferAccelerationBucket(bucketName)
}
// If transfer acceleration is requested set new host.
// For more details about enabling transfer acceleration read here.
// http://docs.aws.amazon.com/AmazonS3/latest/dev/transfer-acceleration.html
host = c.s3AccelerateEndpoint
} else {
// Fetch new host based on the bucket location.
host = getS3Endpoint(bucketLocation)
}
}
// Save scheme.
scheme := url.Scheme
scheme := c.endpointURL.Scheme
urlStr := scheme + "://" + host + "/"
// Make URL only if bucketName is available, otherwise use the
// endpoint URL.
if bucketName != "" {
// Save if target url will have buckets which suppport virtual host.
isVirtualHostStyle := isVirtualHostSupported(c.endpointURL, bucketName)
isVirtualHostStyle := s3utils.IsVirtualHostSupported(c.endpointURL, bucketName)
// If endpoint supports virtual host style use that always.
// Currently only S3 and Google Cloud Storage would support
@ -658,19 +716,19 @@ func (c Client) makeTargetURL(bucketName, objectName, bucketLocation string, que
if isVirtualHostStyle {
urlStr = scheme + "://" + bucketName + "." + host + "/"
if objectName != "" {
urlStr = urlStr + urlEncodePath(objectName)
urlStr = urlStr + s3utils.EncodePath(objectName)
}
} else {
// If not fall back to using path style.
urlStr = urlStr + bucketName + "/"
if objectName != "" {
urlStr = urlStr + urlEncodePath(objectName)
urlStr = urlStr + s3utils.EncodePath(objectName)
}
}
}
// If there are any query values, add them to the end.
if len(queryValues) > 0 {
urlStr = urlStr + "?" + queryEncode(queryValues)
urlStr = urlStr + "?" + s3utils.QueryEncode(queryValues)
}
u, err := url.Parse(urlStr)
if err != nil {

View File

@ -18,7 +18,6 @@ package minio
import (
"bytes"
crand "crypto/rand"
"errors"
"io"
"io/ioutil"
@ -43,10 +42,10 @@ func TestMakeBucketErrorV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -89,10 +88,10 @@ func TestGetObjectClosedTwiceV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -113,13 +112,8 @@ func TestGetObjectClosedTwiceV2(t *testing.T) {
t.Fatal("Error:", err, bucketName)
}
// Generate data more than 32K
buf := make([]byte, rand.Intn(1<<20)+32*1024)
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error:", err)
}
// Generate data more than 32K.
buf := bytes.Repeat([]byte("h"), rand.Intn(1<<20)+32*1024)
// Save the data
objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
@ -174,10 +168,10 @@ func TestRemovePartiallyUploadedV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -198,15 +192,18 @@ func TestRemovePartiallyUploadedV2(t *testing.T) {
t.Fatal("Error:", err, bucketName)
}
r := bytes.NewReader(bytes.Repeat([]byte("a"), 128*1024))
reader, writer := io.Pipe()
go func() {
i := 0
for i < 25 {
_, err = io.CopyN(writer, crand.Reader, 128*1024)
_, err = io.CopyN(writer, r, 128*1024)
if err != nil {
t.Fatal("Error:", err, bucketName)
}
i++
r.Seek(0, 0)
}
writer.CloseWithError(errors.New("Proactively closed to be verified later."))
}()
@ -241,10 +238,10 @@ func TestResumablePutObjectV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -271,8 +268,9 @@ func TestResumablePutObjectV2(t *testing.T) {
t.Fatal("Error:", err)
}
r := bytes.NewReader(bytes.Repeat([]byte("b"), 11*1024*1024))
// Copy 11MiB worth of random data.
n, err := io.CopyN(file, crand.Reader, 11*1024*1024)
n, err := io.CopyN(file, r, 11*1024*1024)
if err != nil {
t.Fatal("Error:", err)
}
@ -352,10 +350,10 @@ func TestFPutObjectV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -382,7 +380,8 @@ func TestFPutObjectV2(t *testing.T) {
t.Fatal("Error:", err)
}
n, err := io.CopyN(file, crand.Reader, 11*1024*1024)
r := bytes.NewReader(bytes.Repeat([]byte("b"), 11*1024*1024))
n, err := io.CopyN(file, r, 11*1024*1024)
if err != nil {
t.Fatal("Error:", err)
}
@ -500,10 +499,10 @@ func TestResumableFPutObjectV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -529,7 +528,8 @@ func TestResumableFPutObjectV2(t *testing.T) {
t.Fatal("Error:", err)
}
n, err := io.CopyN(file, crand.Reader, 11*1024*1024)
r := bytes.NewReader(bytes.Repeat([]byte("b"), 11*1024*1024))
n, err := io.CopyN(file, r, 11*1024*1024)
if err != nil {
t.Fatal("Error:", err)
}
@ -577,10 +577,10 @@ func TestMakeBucketRegionsV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -628,10 +628,10 @@ func TestGetObjectReadSeekFunctionalV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -652,15 +652,10 @@ func TestGetObjectReadSeekFunctionalV2(t *testing.T) {
t.Fatal("Error:", err, bucketName)
}
// Generate data more than 32K
buf := make([]byte, rand.Intn(1<<20)+32*1024)
// Generate data more than 32K.
buf := bytes.Repeat([]byte("2"), rand.Intn(1<<20)+32*1024)
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error:", err)
}
// Save the data
// Save the data.
objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
n, err := c.PutObject(bucketName, objectName, bytes.NewReader(buf), "binary/octet-stream")
if err != nil {
@ -716,7 +711,7 @@ func TestGetObjectReadSeekFunctionalV2(t *testing.T) {
}
var buffer1 bytes.Buffer
if n, err = io.CopyN(&buffer1, r, st.Size); err != nil {
if _, err = io.CopyN(&buffer1, r, st.Size); err != nil {
if err != io.EOF {
t.Fatal("Error:", err)
}
@ -766,10 +761,10 @@ func TestGetObjectReadAtFunctionalV2(t *testing.T) {
// Instantiate new minio client object.
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -791,12 +786,7 @@ func TestGetObjectReadAtFunctionalV2(t *testing.T) {
}
// Generate data more than 32K
buf := make([]byte, rand.Intn(1<<20)+32*1024)
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error:", err)
}
buf := bytes.Repeat([]byte("8"), rand.Intn(1<<20)+32*1024)
// Save the data
objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
@ -907,10 +897,10 @@ func TestCopyObjectV2(t *testing.T) {
// Instantiate new minio client object
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -938,12 +928,7 @@ func TestCopyObjectV2(t *testing.T) {
}
// Generate data more than 32K
buf := make([]byte, rand.Intn(1<<20)+32*1024)
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error:", err)
}
buf := bytes.Repeat([]byte("9"), rand.Intn(1<<20)+32*1024)
// Save the data
objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
@ -958,7 +943,7 @@ func TestCopyObjectV2(t *testing.T) {
}
// Set copy conditions.
copyConds := NewCopyConditions()
copyConds := CopyConditions{}
err = copyConds.SetModified(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC))
if err != nil {
t.Fatal("Error:", err)
@ -1029,10 +1014,10 @@ func TestFunctionalV2(t *testing.T) {
rand.Seed(time.Now().Unix())
c, err := NewV2(
"s3.amazonaws.com",
os.Getenv("S3_ADDRESS"),
os.Getenv("ACCESS_KEY"),
os.Getenv("SECRET_KEY"),
true,
mustParseBool(os.Getenv("S3_SECURE")),
)
if err != nil {
t.Fatal("Error:", err)
@ -1109,11 +1094,7 @@ func TestFunctionalV2(t *testing.T) {
objectName := bucketName + "unique"
// Generate data
buf := make([]byte, rand.Intn(1<<19))
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error: ", err)
}
buf := bytes.Repeat([]byte("n"), rand.Intn(1<<19))
n, err := c.PutObject(bucketName, objectName, bytes.NewReader(buf), "")
if err != nil {
@ -1243,11 +1224,9 @@ func TestFunctionalV2(t *testing.T) {
if err != nil {
t.Fatal("Error: ", err)
}
buf = make([]byte, rand.Intn(1<<20))
_, err = io.ReadFull(crand.Reader, buf)
if err != nil {
t.Fatal("Error: ", err)
}
// Generate data more than 32K
buf = bytes.Repeat([]byte("1"), rand.Intn(1<<20)+32*1024)
req, err := http.NewRequest("PUT", presignedPutURL.String(), bytes.NewReader(buf))
if err != nil {
t.Fatal("Error: ", err)

File diff suppressed because it is too large Load Diff

View File

@ -18,11 +18,9 @@ package minio
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"testing"
@ -202,49 +200,6 @@ func TestTempFile(t *testing.T) {
}
}
// Tests url encoding.
func TestEncodeURL2Path(t *testing.T) {
type urlStrings struct {
objName string
encodedObjName string
}
bucketName := "bucketName"
want := []urlStrings{
{
objName: "本語",
encodedObjName: "%E6%9C%AC%E8%AA%9E",
},
{
objName: "本語.1",
encodedObjName: "%E6%9C%AC%E8%AA%9E.1",
},
{
objName: ">123>3123123",
encodedObjName: "%3E123%3E3123123",
},
{
objName: "test 1 2.txt",
encodedObjName: "test%201%202.txt",
},
{
objName: "test++ 1.txt",
encodedObjName: "test%2B%2B%201.txt",
},
}
for _, o := range want {
u, err := url.Parse(fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucketName, o.objName))
if err != nil {
t.Fatal("Error:", err)
}
urlPath := "/" + bucketName + "/" + o.encodedObjName
if urlPath != encodeURL2Path(u) {
t.Fatal("Error")
}
}
}
// Tests error response structure.
func TestErrorResponse(t *testing.T) {
var err error
@ -270,53 +225,6 @@ func TestErrorResponse(t *testing.T) {
}
}
// Tests signature calculation.
func TestSignatureCalculation(t *testing.T) {
req, err := http.NewRequest("GET", "https://s3.amazonaws.com", nil)
if err != nil {
t.Fatal("Error:", err)
}
req = signV4(*req, "", "", "us-east-1")
if req.Header.Get("Authorization") != "" {
t.Fatal("Error: anonymous credentials should not have Authorization header.")
}
req = preSignV4(*req, "", "", "us-east-1", 0)
if strings.Contains(req.URL.RawQuery, "X-Amz-Signature") {
t.Fatal("Error: anonymous credentials should not have Signature query resource.")
}
req = signV2(*req, "", "")
if req.Header.Get("Authorization") != "" {
t.Fatal("Error: anonymous credentials should not have Authorization header.")
}
req = preSignV2(*req, "", "", 0)
if strings.Contains(req.URL.RawQuery, "Signature") {
t.Fatal("Error: anonymous credentials should not have Signature query resource.")
}
req = signV4(*req, "ACCESS-KEY", "SECRET-KEY", "us-east-1")
if req.Header.Get("Authorization") == "" {
t.Fatal("Error: normal credentials should have Authorization header.")
}
req = preSignV4(*req, "ACCESS-KEY", "SECRET-KEY", "us-east-1", 0)
if !strings.Contains(req.URL.RawQuery, "X-Amz-Signature") {
t.Fatal("Error: normal credentials should have Signature query resource.")
}
req = signV2(*req, "ACCESS-KEY", "SECRET-KEY")
if req.Header.Get("Authorization") == "" {
t.Fatal("Error: normal credentials should have Authorization header.")
}
req = preSignV2(*req, "ACCESS-KEY", "SECRET-KEY", 0)
if !strings.Contains(req.URL.RawQuery, "Signature") {
t.Fatal("Error: normal credentials should not have Signature query resource.")
}
}
// Tests signature type.
func TestSignatureType(t *testing.T) {
clnt := Client{}
@ -354,11 +262,11 @@ func TestBucketPolicyTypes(t *testing.T) {
// Tests optimal part size.
func TestPartSize(t *testing.T) {
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(5000000000000000000)
_, _, _, err := optimalPartInfo(5000000000000000000)
if err == nil {
t.Fatal("Error: should fail")
}
totalPartsCount, partSize, lastPartSize, err = optimalPartInfo(5497558138880)
totalPartsCount, partSize, lastPartSize, err := optimalPartInfo(5497558138880)
if err != nil {
t.Fatal("Error: ", err)
}
@ -371,7 +279,7 @@ func TestPartSize(t *testing.T) {
if lastPartSize != 134217728 {
t.Fatalf("Error: expecting last part size of 241172480: got %v instead", lastPartSize)
}
totalPartsCount, partSize, lastPartSize, err = optimalPartInfo(5000000000)
_, partSize, _, err = optimalPartInfo(5000000000)
if err != nil {
t.Fatal("Error:", err)
}

View File

@ -23,6 +23,9 @@ import (
"path"
"strings"
"sync"
"github.com/minio/minio-go/pkg/s3signer"
"github.com/minio/minio-go/pkg/s3utils"
)
// bucketLocationCache - Provides simple mechanism to hold bucket
@ -85,7 +88,7 @@ func (c Client) getBucketLocation(bucketName string) (string, error) {
return location, nil
}
if isAmazonChinaEndpoint(c.endpointURL) {
if s3utils.IsAmazonChinaEndpoint(c.endpointURL) {
// For china specifically we need to set everything to
// cn-north-1 for now, there is no easier way until AWS S3
// provides a cleaner compatible API across "us-east-1" and
@ -160,10 +163,7 @@ func (c Client) getBucketLocationRequest(bucketName string) (*http.Request, erro
urlValues.Set("location", "")
// Set get bucket location always as path style.
targetURL, err := url.Parse(c.endpointURL)
if err != nil {
return nil, err
}
targetURL := c.endpointURL
targetURL.Path = path.Join(bucketName, "") + "/"
targetURL.RawQuery = urlValues.Encode()
@ -189,9 +189,9 @@ func (c Client) getBucketLocationRequest(bucketName string) (*http.Request, erro
// Sign the request.
if c.signature.isV4() {
req = signV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
req = s3signer.SignV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
} else if c.signature.isV2() {
req = signV2(*req, c.accessKeyID, c.secretAccessKey)
req = s3signer.SignV2(*req, c.accessKeyID, c.secretAccessKey)
}
return req, nil
}

View File

@ -26,6 +26,8 @@ import (
"path"
"reflect"
"testing"
"github.com/minio/minio-go/pkg/s3signer"
)
// Test validates `newBucketLocationCache`.
@ -70,14 +72,12 @@ func TestGetBucketLocationRequest(t *testing.T) {
urlValues.Set("location", "")
// Set get bucket location always as path style.
targetURL, err := url.Parse(c.endpointURL)
if err != nil {
return nil, err
}
targetURL := c.endpointURL
targetURL.Path = path.Join(bucketName, "") + "/"
targetURL.RawQuery = urlValues.Encode()
// Get a new HTTP request for the method.
var err error
req, err = http.NewRequest("GET", targetURL.String(), nil)
if err != nil {
return nil, err
@ -93,9 +93,9 @@ func TestGetBucketLocationRequest(t *testing.T) {
// Sign the request.
if c.signature.isV4() {
req = signV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
req = s3signer.SignV4(*req, c.accessKeyID, c.secretAccessKey, "us-east-1")
} else if c.signature.isV2() {
req = signV2(*req, c.accessKeyID, c.secretAccessKey)
req = s3signer.SignV2(*req, c.accessKeyID, c.secretAccessKey)
}
return req, nil

View File

@ -84,7 +84,7 @@ func (arn Arn) String() string {
// NotificationConfig - represents one single notification configuration
// such as topic, queue or lambda configuration.
type NotificationConfig struct {
Id string `xml:"Id,omitempty"`
ID string `xml:"Id,omitempty"`
Arn Arn `xml:"-"`
Events []NotificationEventType `xml:"Event"`
Filter *Filter `xml:"Filter,omitempty"`

View File

@ -18,7 +18,7 @@ package minio
/// Multipart upload defaults.
// miniPartSize - minimum part size 5MiB per object after which
// miniPartSize - minimum part size 64MiB per object after which
// putObject behaves internally as multipart.
const minPartSize = 1024 * 1024 * 64
@ -44,3 +44,9 @@ const optimalReadBufferSize = 1024 * 1024 * 5
// unsignedPayload - value to be set to X-Amz-Content-Sha256 header when
// we don't want to sign the request payload
const unsignedPayload = "UNSIGNED-PAYLOAD"
// Signature related constants.
const (
signV4Algorithm = "AWS4-HMAC-SHA256"
iso8601DateFormat = "20060102T150405Z"
)

View File

@ -41,11 +41,13 @@ type CopyConditions struct {
conditions []copyCondition
}
// NewCopyConditions - Instantiate new list of conditions.
// NewCopyConditions - Instantiate new list of conditions. This
// function is left behind for backward compatibility. The idiomatic
// way to set an empty set of copy conditions is,
// ``copyConditions := CopyConditions{}``.
//
func NewCopyConditions() CopyConditions {
return CopyConditions{
conditions: make([]copyCondition, 0),
}
return CopyConditions{}
}
// SetMatchETag - set match etag.

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ func main() {
// New returns an Amazon S3 compatible client object. API compatibility (v2 or v4) is automatically
// determined based on the Endpoint value.
minioClient, err := minio.New("play.minio.io:9000", "YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", true)
minioClient, err := minio.New("play.minio.io:9000", "YOUR-ACCESS", "YOUR-SECRET", true)
if err != nil {
log.Fatalln(err)
}
@ -46,30 +46,11 @@ func main() {
// Indicate to our routine to exit cleanly upon return.
defer close(doneCh)
// Fetch the bucket location.
location, err := minioClient.GetBucketLocation("YOUR-BUCKET")
if err != nil {
log.Fatalln(err)
}
// Construct a new account Arn.
accountArn := minio.NewArn("minio", "sns", location, "your-account-id", "listen")
topicConfig := minio.NewNotificationConfig(accountArn)
topicConfig.AddEvents(minio.ObjectCreatedAll, minio.ObjectRemovedAll)
topicConfig.AddFilterPrefix("photos/")
topicConfig.AddFilterSuffix(".jpg")
// Now, set all previously created notification configs
bucketNotification := minio.BucketNotification{}
bucketNotification.AddTopic(topicConfig)
err = minioClient.SetBucketNotification("YOUR-BUCKET", bucketNotification)
if err != nil {
log.Fatalln("Error: " + err.Error())
}
log.Println("Success")
// Listen for bucket notifications on "mybucket" filtered by accountArn "arn:minio:sns:<location>:<your-account-id>:listen".
for notificationInfo := range minioClient.ListenBucketNotification("YOUR-BUCKET", accountArn, doneCh) {
// Listen for bucket notifications on "mybucket" filtered by prefix, suffix and events.
for notificationInfo := range minioClient.ListenBucketNotification("YOUR-BUCKET", "PREFIX", "SUFFIX", []string{
"s3:ObjectCreated:*",
"s3:ObjectRemoved:*",
}, doneCh) {
if notificationInfo.Err != nil {
log.Fatalln(notificationInfo.Err)
}

View File

@ -38,10 +38,14 @@ func main() {
log.Fatalln(err)
}
err = s3Client.BucketExists("my-bucketname")
found, err := s3Client.BucketExists("my-bucketname")
if err != nil {
log.Fatalln(err)
}
log.Println("Success")
if found {
log.Println("Bucket found.")
} else {
log.Println("Bucket not found.")
}
}

View File

@ -45,7 +45,7 @@ func main() {
// All following conditions are allowed and can be combined together.
// Set copy conditions.
var copyConds = minio.NewCopyConditions()
var copyConds = minio.CopyConditions{}
// Set modified condition, copy object modified since 2014 April.
copyConds.SetModified(time.Date(2014, time.April, 0, 0, 0, 0, 0, time.UTC))

View File

@ -0,0 +1,56 @@
// +build ignore
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015, 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"log"
"github.com/minio/minio-go"
)
func main() {
// Note: YOUR-ACCESSKEYID, YOUR-SECRETACCESSKEY and my-bucketname are
// dummy values, please replace them with original values.
// Requests are always secure (HTTPS) by default. Set secure=false to enable insecure (HTTP) access.
// This boolean value is the last argument for New().
// New returns an Amazon S3 compatible client object. API compatibility (v2 or v4) is automatically
// determined based on the Endpoint value.
s3Client, err := minio.New("s3.amazonaws.com", "YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", true)
if err != nil {
log.Fatalln(err)
}
// s3Client.TraceOn(os.Stderr)
// Fetch the policy at 'my-objectprefix'.
policies, err := s3Client.ListBucketPolicies("my-bucketname", "my-objectprefix")
if err != nil {
log.Fatalln(err)
}
// ListBucketPolicies returns a map of objects policy rules and their associated permissions
// e.g. mybucket/downloadfolder/* => readonly
// mybucket/shared/* => readwrite
for resource, permission := range policies {
log.Println(resource, " => ", permission)
}
}

View File

@ -0,0 +1,56 @@
// +build ignore
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"log"
"os"
"github.com/minio/minio-go"
)
func main() {
// Note: YOUR-ACCESSKEYID, YOUR-SECRETACCESSKEY, my-testfile, my-bucketname and
// my-objectname are dummy values, please replace them with original values.
// Requests are always secure (HTTPS) by default. Set secure=false to enable insecure (HTTP) access.
// This boolean value is the last argument for New().
// New returns an Amazon S3 compatible client object. API compatibility (v2 or v4) is automatically
// determined based on the Endpoint value.
s3Client, err := minio.New("s3.amazonaws.com", "YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", true)
if err != nil {
log.Fatalln(err)
}
// Enable S3 transfer accelerate endpoint.
s3Client.S3TransferAccelerate("s3-accelerate.amazonaws.com")
object, err := os.Open("my-testfile")
if err != nil {
log.Fatalln(err)
}
defer object.Close()
n, err := s3Client.PutObject("my-bucketname", "my-objectname", object, "application/octet-stream")
if err != nil {
log.Fatalln(err)
}
log.Println("Uploaded", "my-objectname", " of size: ", n, "Successfully.")
}

View File

@ -0,0 +1,61 @@
// +build ignore
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"log"
"strconv"
"github.com/minio/minio-go"
)
func main() {
// Note: YOUR-ACCESSKEYID, YOUR-SECRETACCESSKEY, my-bucketname and my-objectname
// are dummy values, please replace them with original values.
// Requests are always secure (HTTPS) by default. Set secure=false to enable insecure (HTTP) access.
// This boolean value is the last argument for New().
// New returns an Amazon S3 compatible client object. API compatibility (v2 or v4) is automatically
// determined based on the Endpoint value.
s3Client, err := minio.New("s3.amazonaws.com", "YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", true)
if err != nil {
log.Fatalln(err)
}
objectsCh := make(chan string)
// Send object names that are needed to be removed to objectsCh
go func() {
defer close(objectsCh)
for i := 0; i < 10; i++ {
objectsCh <- "/path/to/my-objectname" + strconv.Itoa(i)
}
}()
// Call RemoveObjects API
errorCh := s3Client.RemoveObjects("my-bucketname", objectsCh)
// Print errors received from RemoveObjects API
for e := range errorCh {
log.Fatalln("Failed to remove " + e.ObjectName + ", error: " + e.Err.Error())
}
log.Println("Success")
}

View File

@ -22,6 +22,7 @@ import (
"log"
"github.com/minio/minio-go"
"github.com/minio/minio-go/pkg/policy"
)
func main() {
@ -41,11 +42,11 @@ func main() {
// s3Client.TraceOn(os.Stderr)
// Description of policy input.
// minio.BucketPolicyNone - Remove any previously applied bucket policy at a prefix.
// minio.BucketPolicyReadOnly - Set read-only operations at a prefix.
// minio.BucketPolicyWriteOnly - Set write-only operations at a prefix.
// minio.BucketPolicyReadWrite - Set read-write operations at a prefix.
err = s3Client.SetBucketPolicy("my-bucketname", "my-objectprefix", minio.BucketPolicyReadWrite)
// policy.BucketPolicyNone - Remove any previously applied bucket policy at a prefix.
// policy.BucketPolicyReadOnly - Set read-only operations at a prefix.
// policy.BucketPolicyWriteOnly - Set write-only operations at a prefix.
// policy.BucketPolicyReadWrite - Set read-write operations at a prefix.
err = s3Client.SetBucketPolicy("my-bucketname", "my-objectprefix", policy.BucketPolicyReadWrite)
if err != nil {
log.Fatalln(err)
}

View File

@ -34,7 +34,7 @@ const (
BucketPolicyWriteOnly = "writeonly"
)
// isValidBucketPolicy - Is provided policy value supported.
// IsValidBucketPolicy - returns true if policy is valid and supported, false otherwise.
func (p BucketPolicy) IsValidBucketPolicy() bool {
switch p {
case BucketPolicyNone, BucketPolicyReadOnly, BucketPolicyReadWrite, BucketPolicyWriteOnly:
@ -508,7 +508,7 @@ func getObjectPolicy(statement Statement) (readOnly bool, writeOnly bool) {
return readOnly, writeOnly
}
// Returns policy of given bucket name, prefix in given statements.
// GetPolicy - Returns policy of given bucket name, prefix in given statements.
func GetPolicy(statements []Statement, bucketName string, prefix string) BucketPolicy {
bucketResource := awsResourcePrefix + bucketName
objectResource := awsResourcePrefix + bucketName + "/" + prefix + "*"
@ -563,8 +563,34 @@ func GetPolicy(statements []Statement, bucketName string, prefix string) BucketP
return policy
}
// Returns new statements containing policy of given bucket name and
// prefix are appended.
// GetPolicies - returns a map of policies rules of given bucket name, prefix in given statements.
func GetPolicies(statements []Statement, bucketName string) map[string]BucketPolicy {
policyRules := map[string]BucketPolicy{}
objResources := set.NewStringSet()
// Search all resources related to objects policy
for _, s := range statements {
for r := range s.Resources {
if strings.HasPrefix(r, awsResourcePrefix+bucketName+"/") {
objResources.Add(r)
}
}
}
// Pretend that policy resource as an actual object and fetch its policy
for r := range objResources {
// Put trailing * if exists in asterisk
asterisk := ""
if strings.HasSuffix(r, "*") {
r = r[:len(r)-1]
asterisk = "*"
}
objectPath := r[len(awsResourcePrefix+bucketName)+1 : len(r)]
p := GetPolicy(statements, bucketName, objectPath)
policyRules[bucketName+"/"+objectPath+asterisk] = p
}
return policyRules
}
// SetPolicy - Returns new statements containing policy of given bucket name and prefix are appended.
func SetPolicy(statements []Statement, policy BucketPolicy, bucketName string, prefix string) []Statement {
out := removeStatements(statements, bucketName, prefix)
// fmt.Println("out = ")

View File

@ -19,6 +19,7 @@ package policy
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/minio/minio-go/pkg/set"
@ -1376,6 +1377,104 @@ func TestGetObjectPolicy(t *testing.T) {
}
}
// GetPolicyRules is called and the result is validated
func TestListBucketPolicies(t *testing.T) {
// Condition for read objects
downloadCondMap := make(ConditionMap)
downloadCondKeyMap := make(ConditionKeyMap)
downloadCondKeyMap.Add("s3:prefix", set.CreateStringSet("download"))
downloadCondMap.Add("StringEquals", downloadCondKeyMap)
// Condition for readwrite objects
downloadUploadCondMap := make(ConditionMap)
downloadUploadCondKeyMap := make(ConditionKeyMap)
downloadUploadCondKeyMap.Add("s3:prefix", set.CreateStringSet("both"))
downloadUploadCondMap.Add("StringEquals", downloadUploadCondKeyMap)
testCases := []struct {
statements []Statement
bucketName string
prefix string
expectedResult map[string]BucketPolicy
}{
// Empty statements, bucket name and prefix.
{[]Statement{}, "", "", map[string]BucketPolicy{}},
// Non-empty statements, empty bucket name and empty prefix.
{[]Statement{{
Actions: readOnlyBucketActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
}}, "", "", map[string]BucketPolicy{}},
// Empty statements, non-empty bucket name and empty prefix.
{[]Statement{}, "mybucket", "", map[string]BucketPolicy{}},
// Readonly object statement
{[]Statement{
{
Actions: commonBucketActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
},
{
Actions: readOnlyBucketActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Conditions: downloadCondMap,
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
},
{
Actions: readOnlyObjectActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket/download*"),
}}, "mybucket", "", map[string]BucketPolicy{"mybucket/download*": BucketPolicyReadOnly}},
// Write Only
{[]Statement{
{
Actions: commonBucketActions.Union(writeOnlyBucketActions),
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
},
{
Actions: writeOnlyObjectActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket/upload*"),
}}, "mybucket", "", map[string]BucketPolicy{"mybucket/upload*": BucketPolicyWriteOnly}},
// Readwrite
{[]Statement{
{
Actions: commonBucketActions.Union(writeOnlyBucketActions),
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
},
{
Actions: readOnlyBucketActions,
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Conditions: downloadUploadCondMap,
Resources: set.CreateStringSet("arn:aws:s3:::mybucket"),
},
{
Actions: writeOnlyObjectActions.Union(readOnlyObjectActions),
Effect: "Allow",
Principal: User{AWS: set.CreateStringSet("*")},
Resources: set.CreateStringSet("arn:aws:s3:::mybucket/both*"),
}}, "mybucket", "", map[string]BucketPolicy{"mybucket/both*": BucketPolicyReadWrite}},
}
for _, testCase := range testCases {
policyRules := GetPolicies(testCase.statements, testCase.bucketName)
if !reflect.DeepEqual(testCase.expectedResult, policyRules) {
t.Fatalf("%+v:\n expected: %+v, got: %+v", testCase, testCase.expectedResult, policyRules)
}
}
}
// GetPolicy() is called and the result is validated.
func TestGetPolicy(t *testing.T) {
helloCondMap := make(ConditionMap)

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package minio
package s3signer
import (
"bytes"
@ -29,6 +29,8 @@ import (
"strconv"
"strings"
"time"
"github.com/minio/minio-go/pkg/s3utils"
)
// Signature and API related constants.
@ -45,22 +47,22 @@ func encodeURL2Path(u *url.URL) (path string) {
bucketName := hostSplits[0]
path = "/" + bucketName
path += u.Path
path = urlEncodePath(path)
path = s3utils.EncodePath(path)
return
}
if strings.HasSuffix(u.Host, ".storage.googleapis.com") {
path = "/" + strings.TrimSuffix(u.Host, ".storage.googleapis.com")
path += u.Path
path = urlEncodePath(path)
path = s3utils.EncodePath(path)
return
}
path = urlEncodePath(u.Path)
path = s3utils.EncodePath(u.Path)
return
}
// preSignV2 - presign the request in following style.
// PreSignV2 - presign the request in following style.
// https://${S3_BUCKET}.s3.amazonaws.com/${S3_OBJECT}?AWSAccessKeyId=${S3_ACCESS_KEY}&Expires=${TIMESTAMP}&Signature=${SIGNATURE}.
func preSignV2(req http.Request, accessKeyID, secretAccessKey string, expires int64) *http.Request {
func PreSignV2(req http.Request, accessKeyID, secretAccessKey string, expires int64) *http.Request {
// Presign is not needed for anonymous credentials.
if accessKeyID == "" || secretAccessKey == "" {
return &req
@ -95,18 +97,18 @@ func preSignV2(req http.Request, accessKeyID, secretAccessKey string, expires in
query.Set("Expires", strconv.FormatInt(epochExpires, 10))
// Encode query and save.
req.URL.RawQuery = queryEncode(query)
req.URL.RawQuery = s3utils.QueryEncode(query)
// Save signature finally.
req.URL.RawQuery += "&Signature=" + urlEncodePath(signature)
req.URL.RawQuery += "&Signature=" + s3utils.EncodePath(signature)
// Return.
return &req
}
// postPresignSignatureV2 - presigned signature for PostPolicy
// PostPresignSignatureV2 - presigned signature for PostPolicy
// request.
func postPresignSignatureV2(policyBase64, secretAccessKey string) string {
func PostPresignSignatureV2(policyBase64, secretAccessKey string) string {
hm := hmac.New(sha1.New, []byte(secretAccessKey))
hm.Write([]byte(policyBase64))
signature := base64.StdEncoding.EncodeToString(hm.Sum(nil))
@ -129,8 +131,8 @@ func postPresignSignatureV2(policyBase64, secretAccessKey string) string {
//
// CanonicalizedProtocolHeaders = <described below>
// signV2 sign the request before Do() (AWS Signature Version 2).
func signV2(req http.Request, accessKeyID, secretAccessKey string) *http.Request {
// SignV2 sign the request before Do() (AWS Signature Version 2).
func SignV2(req http.Request, accessKeyID, secretAccessKey string) *http.Request {
// Signature calculation is not needed for anonymous credentials.
if accessKeyID == "" || secretAccessKey == "" {
return &req
@ -257,6 +259,7 @@ func writeCanonicalizedHeaders(buf *bytes.Buffer, req http.Request) {
// have signature-related issues
var resourceList = []string{
"acl",
"delete",
"location",
"logging",
"notification",
@ -286,7 +289,7 @@ func writeCanonicalizedResource(buf *bytes.Buffer, req http.Request, isPreSign b
// Get encoded URL path.
if len(requestURL.Query()) > 0 {
// Keep the usual queries unescaped for string to sign.
query, _ := url.QueryUnescape(queryEncode(requestURL.Query()))
query, _ := url.QueryUnescape(s3utils.QueryEncode(requestURL.Query()))
path = path + "?" + query
}
buf.WriteString(path)

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package minio
package s3signer
import (
"sort"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package minio
package s3signer
import (
"bytes"
@ -24,6 +24,8 @@ import (
"strconv"
"strings"
"time"
"github.com/minio/minio-go/pkg/s3utils"
)
// Signature and API related constants.
@ -101,8 +103,8 @@ func getScope(location string, t time.Time) string {
return scope
}
// getCredential generate a credential string.
func getCredential(accessKeyID, location string, t time.Time) string {
// GetCredential generate a credential string.
func GetCredential(accessKeyID, location string, t time.Time) string {
scope := getScope(location, t)
return accessKeyID + "/" + scope
}
@ -185,7 +187,7 @@ func getCanonicalRequest(req http.Request) string {
req.URL.RawQuery = strings.Replace(req.URL.Query().Encode(), "+", "%20", -1)
canonicalRequest := strings.Join([]string{
req.Method,
urlEncodePath(req.URL.Path),
s3utils.EncodePath(req.URL.Path),
req.URL.RawQuery,
getCanonicalHeaders(req),
getSignedHeaders(req),
@ -202,9 +204,9 @@ func getStringToSignV4(t time.Time, location, canonicalRequest string) string {
return stringToSign
}
// preSignV4 presign the request, in accordance with
// PreSignV4 presign the request, in accordance with
// http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html.
func preSignV4(req http.Request, accessKeyID, secretAccessKey, location string, expires int64) *http.Request {
func PreSignV4(req http.Request, accessKeyID, secretAccessKey, location string, expires int64) *http.Request {
// Presign is not needed for anonymous credentials.
if accessKeyID == "" || secretAccessKey == "" {
return &req
@ -214,7 +216,7 @@ func preSignV4(req http.Request, accessKeyID, secretAccessKey, location string,
t := time.Now().UTC()
// Get credential string.
credential := getCredential(accessKeyID, location, t)
credential := GetCredential(accessKeyID, location, t)
// Get all signed headers.
signedHeaders := getSignedHeaders(req)
@ -246,9 +248,9 @@ func preSignV4(req http.Request, accessKeyID, secretAccessKey, location string,
return &req
}
// postPresignSignatureV4 - presigned signature for PostPolicy
// PostPresignSignatureV4 - presigned signature for PostPolicy
// requests.
func postPresignSignatureV4(policyBase64 string, t time.Time, secretAccessKey, location string) string {
func PostPresignSignatureV4(policyBase64 string, t time.Time, secretAccessKey, location string) string {
// Get signining key.
signingkey := getSigningKey(secretAccessKey, location, t)
// Calculate signature.
@ -256,9 +258,9 @@ func postPresignSignatureV4(policyBase64 string, t time.Time, secretAccessKey, l
return signature
}
// signV4 sign the request before Do(), in accordance with
// SignV4 sign the request before Do(), in accordance with
// http://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html.
func signV4(req http.Request, accessKeyID, secretAccessKey, location string) *http.Request {
func SignV4(req http.Request, accessKeyID, secretAccessKey, location string) *http.Request {
// Signature calculation is not needed for anonymous credentials.
if accessKeyID == "" || secretAccessKey == "" {
return &req
@ -280,7 +282,7 @@ func signV4(req http.Request, accessKeyID, secretAccessKey, location string) *ht
signingKey := getSigningKey(secretAccessKey, location, t)
// Get credential string.
credential := getCredential(accessKeyID, location, t)
credential := GetCredential(accessKeyID, location, t)
// Get all signed headers.
signedHeaders := getSignedHeaders(req)

View File

@ -0,0 +1,70 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015, 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3signer
import (
"net/http"
"strings"
"testing"
)
// Tests signature calculation.
func TestSignatureCalculation(t *testing.T) {
req, err := http.NewRequest("GET", "https://s3.amazonaws.com", nil)
if err != nil {
t.Fatal("Error:", err)
}
req = SignV4(*req, "", "", "us-east-1")
if req.Header.Get("Authorization") != "" {
t.Fatal("Error: anonymous credentials should not have Authorization header.")
}
req = PreSignV4(*req, "", "", "us-east-1", 0)
if strings.Contains(req.URL.RawQuery, "X-Amz-Signature") {
t.Fatal("Error: anonymous credentials should not have Signature query resource.")
}
req = SignV2(*req, "", "")
if req.Header.Get("Authorization") != "" {
t.Fatal("Error: anonymous credentials should not have Authorization header.")
}
req = PreSignV2(*req, "", "", 0)
if strings.Contains(req.URL.RawQuery, "Signature") {
t.Fatal("Error: anonymous credentials should not have Signature query resource.")
}
req = SignV4(*req, "ACCESS-KEY", "SECRET-KEY", "us-east-1")
if req.Header.Get("Authorization") == "" {
t.Fatal("Error: normal credentials should have Authorization header.")
}
req = PreSignV4(*req, "ACCESS-KEY", "SECRET-KEY", "us-east-1", 0)
if !strings.Contains(req.URL.RawQuery, "X-Amz-Signature") {
t.Fatal("Error: normal credentials should have Signature query resource.")
}
req = SignV2(*req, "ACCESS-KEY", "SECRET-KEY")
if req.Header.Get("Authorization") == "" {
t.Fatal("Error: normal credentials should have Authorization header.")
}
req = PreSignV2(*req, "ACCESS-KEY", "SECRET-KEY", 0)
if !strings.Contains(req.URL.RawQuery, "Signature") {
t.Fatal("Error: normal credentials should not have Signature query resource.")
}
}

View File

@ -0,0 +1,39 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3signer
import (
"crypto/hmac"
"crypto/sha256"
)
// unsignedPayload - value to be set to X-Amz-Content-Sha256 header when
const unsignedPayload = "UNSIGNED-PAYLOAD"
// sum256 calculate sha256 sum for an input byte array.
func sum256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
// sumHMAC calculate hmac between two input byte array.
func sumHMAC(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}

View File

@ -0,0 +1,66 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3signer
import (
"fmt"
"net/url"
"testing"
)
// Tests url encoding.
func TestEncodeURL2Path(t *testing.T) {
type urlStrings struct {
objName string
encodedObjName string
}
bucketName := "bucketName"
want := []urlStrings{
{
objName: "本語",
encodedObjName: "%E6%9C%AC%E8%AA%9E",
},
{
objName: "本語.1",
encodedObjName: "%E6%9C%AC%E8%AA%9E.1",
},
{
objName: ">123>3123123",
encodedObjName: "%3E123%3E3123123",
},
{
objName: "test 1 2.txt",
encodedObjName: "test%201%202.txt",
},
{
objName: "test++ 1.txt",
encodedObjName: "test%2B%2B%201.txt",
},
}
for _, o := range want {
u, err := url.Parse(fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucketName, o.objName))
if err != nil {
t.Fatal("Error:", err)
}
urlPath := "/" + bucketName + "/" + o.encodedObjName
if urlPath != encodeURL2Path(u) {
t.Fatal("Error")
}
}
}

View File

@ -0,0 +1,183 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3utils
import (
"bytes"
"encoding/hex"
"net"
"net/url"
"regexp"
"sort"
"strings"
"unicode/utf8"
)
// Sentinel URL is the default url value which is invalid.
var sentinelURL = url.URL{}
// IsValidDomain validates if input string is a valid domain name.
func IsValidDomain(host string) bool {
// See RFC 1035, RFC 3696.
host = strings.TrimSpace(host)
if len(host) == 0 || len(host) > 255 {
return false
}
// host cannot start or end with "-"
if host[len(host)-1:] == "-" || host[:1] == "-" {
return false
}
// host cannot start or end with "_"
if host[len(host)-1:] == "_" || host[:1] == "_" {
return false
}
// host cannot start or end with a "."
if host[len(host)-1:] == "." || host[:1] == "." {
return false
}
// All non alphanumeric characters are invalid.
if strings.ContainsAny(host, "`~!@#$%^&*()+={}[]|\\\"';:><?/") {
return false
}
// No need to regexp match, since the list is non-exhaustive.
// We let it valid and fail later.
return true
}
// IsValidIP parses input string for ip address validity.
func IsValidIP(ip string) bool {
return net.ParseIP(ip) != nil
}
// IsVirtualHostSupported - verifies if bucketName can be part of
// virtual host. Currently only Amazon S3 and Google Cloud Storage
// would support this.
func IsVirtualHostSupported(endpointURL url.URL, bucketName string) bool {
if endpointURL == sentinelURL {
return false
}
// bucketName can be valid but '.' in the hostname will fail SSL
// certificate validation. So do not use host-style for such buckets.
if endpointURL.Scheme == "https" && strings.Contains(bucketName, ".") {
return false
}
// Return true for all other cases
return IsAmazonEndpoint(endpointURL) || IsGoogleEndpoint(endpointURL)
}
// IsAmazonEndpoint - Match if it is exactly Amazon S3 endpoint.
func IsAmazonEndpoint(endpointURL url.URL) bool {
if IsAmazonChinaEndpoint(endpointURL) {
return true
}
return endpointURL.Host == "s3.amazonaws.com"
}
// IsAmazonChinaEndpoint - Match if it is exactly Amazon S3 China endpoint.
// Customers who wish to use the new Beijing Region are required
// to sign up for a separate set of account credentials unique to
// the China (Beijing) Region. Customers with existing AWS credentials
// will not be able to access resources in the new Region, and vice versa.
// For more info https://aws.amazon.com/about-aws/whats-new/2013/12/18/announcing-the-aws-china-beijing-region/
func IsAmazonChinaEndpoint(endpointURL url.URL) bool {
if endpointURL == sentinelURL {
return false
}
return endpointURL.Host == "s3.cn-north-1.amazonaws.com.cn"
}
// IsGoogleEndpoint - Match if it is exactly Google cloud storage endpoint.
func IsGoogleEndpoint(endpointURL url.URL) bool {
if endpointURL == sentinelURL {
return false
}
return endpointURL.Host == "storage.googleapis.com"
}
// Expects ascii encoded strings - from output of urlEncodePath
func percentEncodeSlash(s string) string {
return strings.Replace(s, "/", "%2F", -1)
}
// QueryEncode - encodes query values in their URL encoded form. In
// addition to the percent encoding performed by urlEncodePath() used
// here, it also percent encodes '/' (forward slash)
func QueryEncode(v url.Values) string {
if v == nil {
return ""
}
var buf bytes.Buffer
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
prefix := percentEncodeSlash(EncodePath(k)) + "="
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(prefix)
buf.WriteString(percentEncodeSlash(EncodePath(v)))
}
}
return buf.String()
}
// if object matches reserved string, no need to encode them
var reservedObjectNames = regexp.MustCompile("^[a-zA-Z0-9-_.~/]+$")
// EncodePath encode the strings from UTF-8 byte representations to HTML hex escape sequences
//
// This is necessary since regular url.Parse() and url.Encode() functions do not support UTF-8
// non english characters cannot be parsed due to the nature in which url.Encode() is written
//
// This function on the other hand is a direct replacement for url.Encode() technique to support
// pretty much every UTF-8 character.
func EncodePath(pathName string) string {
if reservedObjectNames.MatchString(pathName) {
return pathName
}
var encodedPathname string
for _, s := range pathName {
if 'A' <= s && s <= 'Z' || 'a' <= s && s <= 'z' || '0' <= s && s <= '9' { // §2.3 Unreserved characters (mark)
encodedPathname = encodedPathname + string(s)
continue
}
switch s {
case '-', '_', '.', '~', '/': // §2.3 Unreserved characters (mark)
encodedPathname = encodedPathname + string(s)
continue
default:
len := utf8.RuneLen(s)
if len < 0 {
// if utf8 cannot convert return the same string as is
return pathName
}
u := make([]byte, len)
utf8.EncodeRune(u, s)
for _, r := range u {
hex := hex.EncodeToString([]byte{r})
encodedPathname = encodedPathname + "%" + strings.ToUpper(hex)
}
}
}
return encodedPathname
}

View File

@ -0,0 +1,284 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage (C) 2015, 2016 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3utils
import (
"net/url"
"testing"
)
// Tests for 'isValidDomain(host string) bool'.
func TestIsValidDomain(t *testing.T) {
testCases := []struct {
// Input.
host string
// Expected result.
result bool
}{
{"s3.amazonaws.com", true},
{"s3.cn-north-1.amazonaws.com.cn", true},
{"s3.amazonaws.com_", false},
{"%$$$", false},
{"s3.amz.test.com", true},
{"s3.%%", false},
{"localhost", true},
{"-localhost", false},
{"", false},
{"\n \t", false},
{" ", false},
}
for i, testCase := range testCases {
result := IsValidDomain(testCase.host)
if testCase.result != result {
t.Errorf("Test %d: Expected isValidDomain test to be '%v', but found '%v' instead", i+1, testCase.result, result)
}
}
}
// Tests validate IP address validator.
func TestIsValidIP(t *testing.T) {
testCases := []struct {
// Input.
ip string
// Expected result.
result bool
}{
{"192.168.1.1", true},
{"192.168.1", false},
{"192.168.1.1.1", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
}
for i, testCase := range testCases {
result := IsValidIP(testCase.ip)
if testCase.result != result {
t.Errorf("Test %d: Expected isValidIP to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.ip, result)
}
}
}
// Tests validate virtual host validator.
func TestIsVirtualHostSupported(t *testing.T) {
testCases := []struct {
url string
bucket string
// Expeceted result.
result bool
}{
{"https://s3.amazonaws.com", "my-bucket", true},
{"https://s3.cn-north-1.amazonaws.com.cn", "my-bucket", true},
{"https://s3.amazonaws.com", "my-bucket.", false},
{"https://amazons3.amazonaws.com", "my-bucket.", false},
{"https://storage.googleapis.com/", "my-bucket", true},
{"https://mystorage.googleapis.com/", "my-bucket", false},
}
for i, testCase := range testCases {
u, err := url.Parse(testCase.url)
if err != nil {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
result := IsVirtualHostSupported(*u, testCase.bucket)
if testCase.result != result {
t.Errorf("Test %d: Expected isVirtualHostSupported to be '%v' for input url \"%s\" and bucket \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, testCase.bucket, result)
}
}
}
// Tests validate Amazon endpoint validator.
func TestIsAmazonEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"https://192.168.1.1", false},
{"192.168.1.1", false},
{"http://storage.googleapis.com", false},
{"https://storage.googleapis.com", false},
{"storage.googleapis.com", false},
{"s3.amazonaws.com", false},
{"https://amazons3.amazonaws.com", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// valid inputs.
{"https://s3.amazonaws.com", true},
{"https://s3.cn-north-1.amazonaws.com.cn", true},
}
for i, testCase := range testCases {
u, err := url.Parse(testCase.url)
if err != nil {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
result := IsAmazonEndpoint(*u)
if testCase.result != result {
t.Errorf("Test %d: Expected isAmazonEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
// Tests validate Amazon S3 China endpoint validator.
func TestIsAmazonChinaEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"https://192.168.1.1", false},
{"192.168.1.1", false},
{"http://storage.googleapis.com", false},
{"https://storage.googleapis.com", false},
{"storage.googleapis.com", false},
{"s3.amazonaws.com", false},
{"https://amazons3.amazonaws.com", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// s3.amazonaws.com is not a valid Amazon S3 China end point.
{"https://s3.amazonaws.com", false},
// valid input.
{"https://s3.cn-north-1.amazonaws.com.cn", true},
}
for i, testCase := range testCases {
u, err := url.Parse(testCase.url)
if err != nil {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
result := IsAmazonChinaEndpoint(*u)
if testCase.result != result {
t.Errorf("Test %d: Expected isAmazonEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
// Tests validate Google Cloud end point validator.
func TestIsGoogleEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"192.168.1.1", false},
{"https://192.168.1.1", false},
{"s3.amazonaws.com", false},
{"http://s3.amazonaws.com", false},
{"https://s3.amazonaws.com", false},
{"https://s3.cn-north-1.amazonaws.com.cn", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// valid inputs.
{"http://storage.googleapis.com", true},
{"https://storage.googleapis.com", true},
}
for i, testCase := range testCases {
u, err := url.Parse(testCase.url)
if err != nil {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
result := IsGoogleEndpoint(*u)
if testCase.result != result {
t.Errorf("Test %d: Expected isGoogleEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
func TestPercentEncodeSlash(t *testing.T) {
testCases := []struct {
input string
output string
}{
{"test123", "test123"},
{"abc,+_1", "abc,+_1"},
{"%40prefix=test%40123", "%40prefix=test%40123"},
{"key1=val1/val2", "key1=val1%2Fval2"},
{"%40prefix=test%40123/", "%40prefix=test%40123%2F"},
}
for i, testCase := range testCases {
receivedOutput := percentEncodeSlash(testCase.input)
if testCase.output != receivedOutput {
t.Errorf(
"Test %d: Input: \"%s\" --> Expected percentEncodeSlash to return \"%s\", but it returned \"%s\" instead!",
i+1, testCase.input, testCase.output,
receivedOutput,
)
}
}
}
// Tests validate the query encoder.
func TestQueryEncode(t *testing.T) {
testCases := []struct {
queryKey string
valueToEncode []string
// Expected result.
result string
}{
{"prefix", []string{"test@123", "test@456"}, "prefix=test%40123&prefix=test%40456"},
{"@prefix", []string{"test@123"}, "%40prefix=test%40123"},
{"@prefix", []string{"a/b/c/"}, "%40prefix=a%2Fb%2Fc%2F"},
{"prefix", []string{"test#123"}, "prefix=test%23123"},
{"prefix#", []string{"test#123"}, "prefix%23=test%23123"},
{"prefix", []string{"test123"}, "prefix=test123"},
{"prefix", []string{"test本語123", "test123"}, "prefix=test%E6%9C%AC%E8%AA%9E123&prefix=test123"},
}
for i, testCase := range testCases {
urlValues := make(url.Values)
for _, valueToEncode := range testCase.valueToEncode {
urlValues.Add(testCase.queryKey, valueToEncode)
}
result := QueryEncode(urlValues)
if testCase.result != result {
t.Errorf("Test %d: Expected queryEncode result to be \"%s\", but found it to be \"%s\" instead", i+1, testCase.result, result)
}
}
}
// Tests validate the URL path encoder.
func TestEncodePath(t *testing.T) {
testCases := []struct {
// Input.
inputStr string
// Expected result.
result string
}{
{"thisisthe%url", "thisisthe%25url"},
{"本語", "%E6%9C%AC%E8%AA%9E"},
{"本語.1", "%E6%9C%AC%E8%AA%9E.1"},
{">123", "%3E123"},
{"myurl#link", "myurl%23link"},
{"space in url", "space%20in%20url"},
{"url+path", "url%2Bpath"},
}
for i, testCase := range testCases {
result := EncodePath(testCase.inputStr)
if testCase.result != result {
t.Errorf("Test %d: Expected queryEncode result to be \"%s\", but found it to be \"%s\" instead", i+1, testCase.result, result)
}
}
}

View File

@ -25,8 +25,8 @@ import (
// StringSet - uses map as set of strings.
type StringSet map[string]struct{}
// keys - returns StringSet keys.
func (set StringSet) keys() []string {
// ToSlice - returns StringSet as string slice.
func (set StringSet) ToSlice() []string {
keys := make([]string, 0, len(set))
for k := range set {
keys = append(keys, k)
@ -141,7 +141,7 @@ func (set StringSet) Union(sset StringSet) StringSet {
// MarshalJSON - converts to JSON data.
func (set StringSet) MarshalJSON() ([]byte, error) {
return json.Marshal(set.keys())
return json.Marshal(set.ToSlice())
}
// UnmarshalJSON - parses JSON data and creates new set with it.
@ -169,7 +169,7 @@ func (set *StringSet) UnmarshalJSON(data []byte) error {
// String - returns printable string of the set.
func (set StringSet) String() string {
return fmt.Sprintf("%s", set.keys())
return fmt.Sprintf("%s", set.ToSlice())
}
// NewStringSet - creates new string set.

View File

@ -17,6 +17,7 @@
package set
import (
"fmt"
"strings"
"testing"
)
@ -320,3 +321,27 @@ func TestStringSetString(t *testing.T) {
}
}
}
// StringSet.ToSlice() is called with series of cases for valid and erroneous inputs and the result is validated.
func TestStringSetToSlice(t *testing.T) {
testCases := []struct {
set StringSet
expectedResult string
}{
// Test empty set.
{NewStringSet(), `[]`},
// Test set with empty value.
{CreateStringSet(""), `[]`},
// Test set with value.
{CreateStringSet("foo"), `[foo]`},
// Test set with value.
{CreateStringSet("foo", "bar"), `[bar foo]`},
}
for _, testCase := range testCases {
sslice := testCase.set.ToSlice()
if str := fmt.Sprintf("%s", sslice); str != testCase.expectedResult {
t.Fatalf("expected: %s, got: %s", testCase.expectedResult, str)
}
}
}

View File

@ -149,6 +149,24 @@ func (p *PostPolicy) SetContentLengthRange(min, max int64) error {
return nil
}
// SetSuccessStatusAction - Sets the status success code of the object for this policy
// based upload.
func (p *PostPolicy) SetSuccessStatusAction(status string) error {
if strings.TrimSpace(status) == "" || status == "" {
return ErrInvalidArgument("Status is empty")
}
policyCond := policyCondition{
matchType: "eq",
condition: "$success_action_status",
value: status,
}
if err := p.addNewPolicy(policyCond); err != nil {
return err
}
p.formData["success_action_status"] = status
return nil
}
// addNewPolicy - internal helper to validate adding new policies.
func (p *PostPolicy) addNewPolicy(policyCond policyCondition) error {
if policyCond.matchType == "" || policyCond.condition == "" || policyCond.value == "" {

View File

@ -0,0 +1,52 @@
package minio
import "time"
// newRetryTimerContinous creates a timer with exponentially increasing delays forever.
func (c Client) newRetryTimerContinous(unit time.Duration, cap time.Duration, jitter float64, doneCh chan struct{}) <-chan int {
attemptCh := make(chan int)
// normalize jitter to the range [0, 1.0]
if jitter < NoJitter {
jitter = NoJitter
}
if jitter > MaxJitter {
jitter = MaxJitter
}
// computes the exponential backoff duration according to
// https://www.awsarchitectureblog.com/2015/03/backoff.html
exponentialBackoffWait := func(attempt int) time.Duration {
// 1<<uint(attempt) below could overflow, so limit the value of attempt
maxAttempt := 30
if attempt > maxAttempt {
attempt = maxAttempt
}
//sleep = random_between(0, min(cap, base * 2 ** attempt))
sleep := unit * time.Duration(1<<uint(attempt))
if sleep > cap {
sleep = cap
}
if jitter != NoJitter {
sleep -= time.Duration(c.random.Float64() * float64(sleep) * jitter)
}
return sleep
}
go func() {
defer close(attemptCh)
var nextBackoff int
for {
select {
// Attempts starts.
case attemptCh <- nextBackoff:
nextBackoff++
case <-doneCh:
// Stop the routine.
return
}
time.Sleep(exponentialBackoffWait(nextBackoff))
}
}()
return attemptCh
}

View File

@ -20,9 +20,12 @@ package minio
// "cn-north-1" adds support for AWS China.
var awsS3EndpointMap = map[string]string{
"us-east-1": "s3.amazonaws.com",
"us-east-2": "s3-us-east-2.amazonaws.com",
"us-west-2": "s3-us-west-2.amazonaws.com",
"us-west-1": "s3-us-west-1.amazonaws.com",
"ca-central-1": "s3.ca-central-1.amazonaws.com",
"eu-west-1": "s3-eu-west-1.amazonaws.com",
"eu-west-2": "s3-eu-west-2.amazonaws.com",
"eu-central-1": "s3-eu-central-1.amazonaws.com",
"ap-south-1": "s3-ap-south-1.amazonaws.com",
"ap-southeast-1": "s3-ap-southeast-1.amazonaws.com",

View File

@ -21,6 +21,7 @@ import (
"encoding/xml"
"io/ioutil"
"net/http"
"strconv"
)
// Contains common used utilities for tests.
@ -62,3 +63,12 @@ func encodeResponse(response interface{}) []byte {
encode.Encode(response)
return bytesBuffer.Bytes()
}
// Convert string to bool and always return true if any error
func mustParseBool(str string) bool {
b, err := strconv.ParseBool(str)
if err != nil {
return true
}
return b
}

View File

@ -17,11 +17,8 @@
package minio
import (
"bytes"
"crypto/hmac"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"io"
"io/ioutil"
@ -29,10 +26,11 @@ import (
"net/http"
"net/url"
"regexp"
"sort"
"strings"
"time"
"unicode/utf8"
"github.com/minio/minio-go/pkg/s3utils"
)
// xmlDecoder provide decoded value in xml.
@ -55,13 +53,6 @@ func sumMD5(data []byte) []byte {
return hash.Sum(nil)
}
// sumHMAC calculate hmac between two input byte array.
func sumHMAC(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
// getEndpointURL - construct a new endpoint.
func getEndpointURL(endpoint string, secure bool) (*url.URL, error) {
if strings.Contains(endpoint, ":") {
@ -69,12 +60,12 @@ func getEndpointURL(endpoint string, secure bool) (*url.URL, error) {
if err != nil {
return nil, err
}
if !isValidIP(host) && !isValidDomain(host) {
if !s3utils.IsValidIP(host) && !s3utils.IsValidDomain(host) {
msg := "Endpoint: " + endpoint + " does not follow ip address or domain name standards."
return nil, ErrInvalidArgument(msg)
}
} else {
if !isValidIP(endpoint) && !isValidDomain(endpoint) {
if !s3utils.IsValidIP(endpoint) && !s3utils.IsValidDomain(endpoint) {
msg := "Endpoint: " + endpoint + " does not follow ip address or domain name standards."
return nil, ErrInvalidArgument(msg)
}
@ -93,45 +84,12 @@ func getEndpointURL(endpoint string, secure bool) (*url.URL, error) {
}
// Validate incoming endpoint URL.
if err := isValidEndpointURL(endpointURL.String()); err != nil {
if err := isValidEndpointURL(*endpointURL); err != nil {
return nil, err
}
return endpointURL, nil
}
// isValidDomain validates if input string is a valid domain name.
func isValidDomain(host string) bool {
// See RFC 1035, RFC 3696.
host = strings.TrimSpace(host)
if len(host) == 0 || len(host) > 255 {
return false
}
// host cannot start or end with "-"
if host[len(host)-1:] == "-" || host[:1] == "-" {
return false
}
// host cannot start or end with "_"
if host[len(host)-1:] == "_" || host[:1] == "_" {
return false
}
// host cannot start or end with a "."
if host[len(host)-1:] == "." || host[:1] == "." {
return false
}
// All non alphanumeric characters are invalid.
if strings.ContainsAny(host, "`~!@#$%^&*()+={}[]|\\\"';:><?/") {
return false
}
// No need to regexp match, since the list is non-exhaustive.
// We let it valid and fail later.
return true
}
// isValidIP parses input string for ip address validity.
func isValidIP(ip string) bool {
return net.ParseIP(ip) != nil
}
// closeResponse close non nil response with any response Body.
// convenient wrapper to drain any remaining data on response body.
//
@ -152,92 +110,24 @@ func closeResponse(resp *http.Response) {
}
}
// isVirtualHostSupported - verifies if bucketName can be part of
// virtual host. Currently only Amazon S3 and Google Cloud Storage
// would support this.
func isVirtualHostSupported(endpointURL string, bucketName string) bool {
url, err := url.Parse(endpointURL)
if err != nil {
return false
}
// bucketName can be valid but '.' in the hostname will fail SSL
// certificate validation. So do not use host-style for such buckets.
if url.Scheme == "https" && strings.Contains(bucketName, ".") {
return false
}
// Return true for all other cases
return isAmazonEndpoint(endpointURL) || isGoogleEndpoint(endpointURL)
}
// Match if it is exactly Amazon S3 endpoint.
func isAmazonEndpoint(endpointURL string) bool {
if isAmazonChinaEndpoint(endpointURL) {
return true
}
url, err := url.Parse(endpointURL)
if err != nil {
return false
}
if url.Host == "s3.amazonaws.com" {
return true
}
return false
}
// Match if it is exactly Amazon S3 China endpoint.
// Customers who wish to use the new Beijing Region are required
// to sign up for a separate set of account credentials unique to
// the China (Beijing) Region. Customers with existing AWS credentials
// will not be able to access resources in the new Region, and vice versa.
// For more info https://aws.amazon.com/about-aws/whats-new/2013/12/18/announcing-the-aws-china-beijing-region/
func isAmazonChinaEndpoint(endpointURL string) bool {
if endpointURL == "" {
return false
}
url, err := url.Parse(endpointURL)
if err != nil {
return false
}
if url.Host == "s3.cn-north-1.amazonaws.com.cn" {
return true
}
return false
}
// Match if it is exactly Google cloud storage endpoint.
func isGoogleEndpoint(endpointURL string) bool {
if endpointURL == "" {
return false
}
url, err := url.Parse(endpointURL)
if err != nil {
return false
}
if url.Host == "storage.googleapis.com" {
return true
}
return false
}
// Sentinel URL is the default url value which is invalid.
var sentinelURL = url.URL{}
// Verify if input endpoint URL is valid.
func isValidEndpointURL(endpointURL string) error {
if endpointURL == "" {
func isValidEndpointURL(endpointURL url.URL) error {
if endpointURL == sentinelURL {
return ErrInvalidArgument("Endpoint url cannot be empty.")
}
url, err := url.Parse(endpointURL)
if err != nil {
if endpointURL.Path != "/" && endpointURL.Path != "" {
return ErrInvalidArgument("Endpoint url cannot have fully qualified paths.")
}
if url.Path != "/" && url.Path != "" {
return ErrInvalidArgument("Endpoint url cannot have fully qualified paths.")
}
if strings.Contains(endpointURL, ".amazonaws.com") {
if !isAmazonEndpoint(endpointURL) {
if strings.Contains(endpointURL.Host, ".amazonaws.com") {
if !s3utils.IsAmazonEndpoint(endpointURL) {
return ErrInvalidArgument("Amazon S3 endpoint should be 's3.amazonaws.com'.")
}
}
if strings.Contains(endpointURL, ".googleapis.com") {
if !isGoogleEndpoint(endpointURL) {
if strings.Contains(endpointURL.Host, ".googleapis.com") {
if !s3utils.IsGoogleEndpoint(endpointURL) {
return ErrInvalidArgument("Google Cloud Storage endpoint should be 'storage.googleapis.com'.")
}
}
@ -260,6 +150,9 @@ func isValidExpiry(expires time.Duration) error {
// style requests instead for such buckets.
var validBucketName = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
// Invalid bucket name with double dot.
var invalidDotBucketName = regexp.MustCompile(`\.\.`)
// isValidBucketName - verify bucket name in accordance with
// - http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html
func isValidBucketName(bucketName string) error {
@ -275,7 +168,7 @@ func isValidBucketName(bucketName string) error {
if bucketName[0] == '.' || bucketName[len(bucketName)-1] == '.' {
return ErrInvalidBucketName("Bucket name cannot start or end with a '.' dot.")
}
if match, _ := regexp.MatchString("\\.\\.", bucketName); match {
if invalidDotBucketName.MatchString(bucketName) {
return ErrInvalidBucketName("Bucket name cannot have successive periods.")
}
if !validBucketName.MatchString(bucketName) {
@ -310,67 +203,25 @@ func isValidObjectPrefix(objectPrefix string) error {
return nil
}
// queryEncode - encodes query values in their URL encoded form.
func queryEncode(v url.Values) string {
if v == nil {
return ""
// make a copy of http.Header
func cloneHeader(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
var buf bytes.Buffer
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
prefix := urlEncodePath(k) + "="
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(prefix)
buf.WriteString(urlEncodePath(v))
}
}
return buf.String()
return h2
}
// urlEncodePath encode the strings from UTF-8 byte representations to HTML hex escape sequences
//
// This is necessary since regular url.Parse() and url.Encode() functions do not support UTF-8
// non english characters cannot be parsed due to the nature in which url.Encode() is written
//
// This function on the other hand is a direct replacement for url.Encode() technique to support
// pretty much every UTF-8 character.
func urlEncodePath(pathName string) string {
// if object matches reserved string, no need to encode them
reservedNames := regexp.MustCompile("^[a-zA-Z0-9-_.~/]+$")
if reservedNames.MatchString(pathName) {
return pathName
// Filter relevant response headers from
// the HEAD, GET http response. The function takes
// a list of headers which are filtered out and
// returned as a new http header.
func filterHeader(header http.Header, filterKeys []string) (filteredHeader http.Header) {
filteredHeader = cloneHeader(header)
for _, key := range filterKeys {
filteredHeader.Del(key)
}
var encodedPathname string
for _, s := range pathName {
if 'A' <= s && s <= 'Z' || 'a' <= s && s <= 'z' || '0' <= s && s <= '9' { // §2.3 Unreserved characters (mark)
encodedPathname = encodedPathname + string(s)
continue
}
switch s {
case '-', '_', '.', '~', '/': // §2.3 Unreserved characters (mark)
encodedPathname = encodedPathname + string(s)
continue
default:
len := utf8.RuneLen(s)
if len < 0 {
// if utf8 cannot convert return the same string as is
return pathName
}
u := make([]byte, len)
utf8.EncodeRune(u, s)
for _, r := range u {
hex := hex.EncodeToString([]byte{r})
encodedPathname = encodedPathname + "%" + strings.ToUpper(hex)
}
}
}
return encodedPathname
return filteredHeader
}

View File

@ -17,11 +17,27 @@ package minio
import (
"fmt"
"net/http"
"net/url"
"testing"
"time"
)
// Tests filter header function by filtering out
// some custom header keys.
func TestFilterHeader(t *testing.T) {
header := http.Header{}
header.Set("Content-Type", "binary/octet-stream")
header.Set("Content-Encoding", "gzip")
newHeader := filterHeader(header, []string{"Content-Type"})
if len(newHeader) > 1 {
t.Fatalf("Unexpected size of the returned header, should be 1, got %d", len(newHeader))
}
if newHeader.Get("Content-Encoding") != "gzip" {
t.Fatalf("Unexpected content-encoding value, expected 'gzip', got %s", newHeader.Get("Content-Encoding"))
}
}
// Tests for 'getEndpointURL(endpoint string, inSecure bool)'.
func TestGetEndpointURL(t *testing.T) {
testCases := []struct {
@ -74,35 +90,6 @@ func TestGetEndpointURL(t *testing.T) {
}
}
// Tests for 'isValidDomain(host string) bool'.
func TestIsValidDomain(t *testing.T) {
testCases := []struct {
// Input.
host string
// Expected result.
result bool
}{
{"s3.amazonaws.com", true},
{"s3.cn-north-1.amazonaws.com.cn", true},
{"s3.amazonaws.com_", false},
{"%$$$", false},
{"s3.amz.test.com", true},
{"s3.%%", false},
{"localhost", true},
{"-localhost", false},
{"", false},
{"\n \t", false},
{" ", false},
}
for i, testCase := range testCases {
result := isValidDomain(testCase.host)
if testCase.result != result {
t.Errorf("Test %d: Expected isValidDomain test to be '%v', but found '%v' instead", i+1, testCase.result, result)
}
}
}
// Tests validate end point validator.
func TestIsValidEndpointURL(t *testing.T) {
testCases := []struct {
@ -125,161 +112,33 @@ func TestIsValidEndpointURL(t *testing.T) {
}
for i, testCase := range testCases {
err := isValidEndpointURL(testCase.url)
var u url.URL
if testCase.url == "" {
u = sentinelURL
} else {
u1, err := url.Parse(testCase.url)
if err != nil {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
u = *u1
}
err := isValidEndpointURL(u)
if err != nil && testCase.shouldPass {
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err.Error())
t.Errorf("Test %d: Expected to pass, but failed with: <ERROR> %s", i+1, err)
}
if err == nil && !testCase.shouldPass {
t.Errorf("Test %d: Expected to fail with <ERROR> \"%s\", but passed instead", i+1, testCase.err.Error())
t.Errorf("Test %d: Expected to fail with <ERROR> \"%s\", but passed instead", i+1, testCase.err)
}
// Failed as expected, but does it fail for the expected reason.
if err != nil && !testCase.shouldPass {
if err.Error() != testCase.err.Error() {
t.Errorf("Test %d: Expected to fail with error \"%s\", but instead failed with error \"%s\" instead", i+1, testCase.err.Error(), err.Error())
t.Errorf("Test %d: Expected to fail with error \"%s\", but instead failed with error \"%s\" instead", i+1, testCase.err, err)
}
}
}
}
// Tests validate IP address validator.
func TestIsValidIP(t *testing.T) {
testCases := []struct {
// Input.
ip string
// Expected result.
result bool
}{
{"192.168.1.1", true},
{"192.168.1", false},
{"192.168.1.1.1", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
}
for i, testCase := range testCases {
result := isValidIP(testCase.ip)
if testCase.result != result {
t.Errorf("Test %d: Expected isValidIP to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.ip, result)
}
}
}
// Tests validate virtual host validator.
func TestIsVirtualHostSupported(t *testing.T) {
testCases := []struct {
url string
bucket string
// Expeceted result.
result bool
}{
{"https://s3.amazonaws.com", "my-bucket", true},
{"https://s3.cn-north-1.amazonaws.com.cn", "my-bucket", true},
{"https://s3.amazonaws.com", "my-bucket.", false},
{"https://amazons3.amazonaws.com", "my-bucket.", false},
{"https://storage.googleapis.com/", "my-bucket", true},
{"https://mystorage.googleapis.com/", "my-bucket", false},
}
for i, testCase := range testCases {
result := isVirtualHostSupported(testCase.url, testCase.bucket)
if testCase.result != result {
t.Errorf("Test %d: Expected isVirtualHostSupported to be '%v' for input url \"%s\" and bucket \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, testCase.bucket, result)
}
}
}
// Tests validate Amazon endpoint validator.
func TestIsAmazonEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"https://192.168.1.1", false},
{"192.168.1.1", false},
{"http://storage.googleapis.com", false},
{"https://storage.googleapis.com", false},
{"storage.googleapis.com", false},
{"s3.amazonaws.com", false},
{"https://amazons3.amazonaws.com", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// valid inputs.
{"https://s3.amazonaws.com", true},
{"https://s3.cn-north-1.amazonaws.com.cn", true},
}
for i, testCase := range testCases {
result := isAmazonEndpoint(testCase.url)
if testCase.result != result {
t.Errorf("Test %d: Expected isAmazonEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
// Tests validate Amazon S3 China endpoint validator.
func TestIsAmazonChinaEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"https://192.168.1.1", false},
{"192.168.1.1", false},
{"http://storage.googleapis.com", false},
{"https://storage.googleapis.com", false},
{"storage.googleapis.com", false},
{"s3.amazonaws.com", false},
{"https://amazons3.amazonaws.com", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// s3.amazonaws.com is not a valid Amazon S3 China end point.
{"https://s3.amazonaws.com", false},
// valid input.
{"https://s3.cn-north-1.amazonaws.com.cn", true},
}
for i, testCase := range testCases {
result := isAmazonChinaEndpoint(testCase.url)
if testCase.result != result {
t.Errorf("Test %d: Expected isAmazonEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
// Tests validate Google Cloud end point validator.
func TestIsGoogleEndpoint(t *testing.T) {
testCases := []struct {
url string
// Expected result.
result bool
}{
{"192.168.1.1", false},
{"https://192.168.1.1", false},
{"s3.amazonaws.com", false},
{"http://s3.amazonaws.com", false},
{"https://s3.amazonaws.com", false},
{"https://s3.cn-north-1.amazonaws.com.cn", false},
{"-192.168.1.1", false},
{"260.192.1.1", false},
// valid inputs.
{"http://storage.googleapis.com", true},
{"https://storage.googleapis.com", true},
}
for i, testCase := range testCases {
result := isGoogleEndpoint(testCase.url)
if testCase.result != result {
t.Errorf("Test %d: Expected isGoogleEndpoint to be '%v' for input \"%s\", but found it to be '%v' instead", i+1, testCase.result, testCase.url, result)
}
}
}
// Tests validate the expiry time validator.
func TestIsValidExpiry(t *testing.T) {
testCases := []struct {
@ -355,56 +214,3 @@ func TestIsValidBucketName(t *testing.T) {
}
}
// Tests validate the query encoder.
func TestQueryEncode(t *testing.T) {
testCases := []struct {
queryKey string
valueToEncode []string
// Expected result.
result string
}{
{"prefix", []string{"test@123", "test@456"}, "prefix=test%40123&prefix=test%40456"},
{"@prefix", []string{"test@123"}, "%40prefix=test%40123"},
{"prefix", []string{"test#123"}, "prefix=test%23123"},
{"prefix#", []string{"test#123"}, "prefix%23=test%23123"},
{"prefix", []string{"test123"}, "prefix=test123"},
{"prefix", []string{"test本語123", "test123"}, "prefix=test%E6%9C%AC%E8%AA%9E123&prefix=test123"},
}
for i, testCase := range testCases {
urlValues := make(url.Values)
for _, valueToEncode := range testCase.valueToEncode {
urlValues.Add(testCase.queryKey, valueToEncode)
}
result := queryEncode(urlValues)
if testCase.result != result {
t.Errorf("Test %d: Expected queryEncode result to be \"%s\", but found it to be \"%s\" instead", i+1, testCase.result, result)
}
}
}
// Tests validate the URL path encoder.
func TestUrlEncodePath(t *testing.T) {
testCases := []struct {
// Input.
inputStr string
// Expected result.
result string
}{
{"thisisthe%url", "thisisthe%25url"},
{"本語", "%E6%9C%AC%E8%AA%9E"},
{"本語.1", "%E6%9C%AC%E8%AA%9E.1"},
{">123", "%3E123"},
{"myurl#link", "myurl%23link"},
{"space in url", "space%20in%20url"},
{"url+path", "url%2Bpath"},
}
for i, testCase := range testCases {
result := urlEncodePath(testCase.inputStr)
if testCase.result != result {
t.Errorf("Test %d: Expected queryEncode result to be \"%s\", but found it to be \"%s\" instead", i+1, testCase.result, result)
}
}
}

View File

@ -14,13 +14,18 @@
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error. For example
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
@ -134,6 +139,18 @@ func (f *fundamental) Format(s fmt.State, verb rune) {
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
@ -157,7 +174,8 @@ func (w *withStack) Format(s fmt.State, verb rune) {
}
}
// Wrap returns an error annotating err with message.
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
@ -173,7 +191,8 @@ func Wrap(err error, message string) error {
}
}
// Wrapf returns an error annotating err with the format specifier.
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
@ -189,6 +208,18 @@ func Wrapf(err error, format string, args ...interface{}) error {
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string

View File

@ -84,6 +84,18 @@ func TestCause(t *testing.T) {
}, {
err: x, // return from errors.New
want: x,
}, {
WithMessage(nil, "whoops"),
nil,
}, {
WithMessage(io.EOF, "whoops"),
io.EOF,
}, {
WithStack(nil),
nil,
}, {
WithStack(io.EOF),
io.EOF,
}}
for i, tt := range tests {
@ -137,23 +149,78 @@ func TestErrorf(t *testing.T) {
}
}
func TestWithStackNil(t *testing.T) {
got := WithStack(nil)
if got != nil {
t.Errorf("WithStack(nil): got %#v, expected nil", got)
}
}
func TestWithStack(t *testing.T) {
tests := []struct {
err error
want string
}{
{io.EOF, "EOF"},
{WithStack(io.EOF), "EOF"},
}
for _, tt := range tests {
got := WithStack(tt.err).Error()
if got != tt.want {
t.Errorf("WithStack(%v): got: %v, want %v", tt.err, got, tt.want)
}
}
}
func TestWithMessageNil(t *testing.T) {
got := WithMessage(nil, "no error")
if got != nil {
t.Errorf("WithMessage(nil, \"no error\"): got %#v, expected nil", got)
}
}
func TestWithMessage(t *testing.T) {
tests := []struct {
err error
message string
want string
}{
{io.EOF, "read error", "read error: EOF"},
{WithMessage(io.EOF, "read error"), "client error", "client error: read error: EOF"},
}
for _, tt := range tests {
got := WithMessage(tt.err, tt.message).Error()
if got != tt.want {
t.Errorf("WithMessage(%v, %q): got: %q, want %q", tt.err, tt.message, got, tt.want)
}
}
}
// errors.New, etc values are not expected to be compared by value
// but the change in errors#27 made them incomparable. Assert that
// various kinds of errors have a functional equality operator, even
// if the result of that equality is always false.
func TestErrorEquality(t *testing.T) {
tests := []struct {
err1, err2 error
}{
{io.EOF, io.EOF},
{io.EOF, nil},
{io.EOF, errors.New("EOF")},
{io.EOF, New("EOF")},
{New("EOF"), New("EOF")},
{New("EOF"), Errorf("EOF")},
{New("EOF"), Wrap(io.EOF, "EOF")},
vals := []error{
nil,
io.EOF,
errors.New("EOF"),
New("EOF"),
Errorf("EOF"),
Wrap(io.EOF, "EOF"),
Wrapf(io.EOF, "EOF%d", 2),
WithMessage(nil, "whoops"),
WithMessage(io.EOF, "whoops"),
WithStack(io.EOF),
WithStack(nil),
}
for _, tt := range tests {
_ = tt.err1 == tt.err2 // mustn't panic
for i := range vals {
for j := range vals {
_ = vals[i] == vals[j] // mustn't panic
}
}
}

View File

@ -35,6 +35,59 @@ func ExampleNew_printf() {
// /home/dfc/go/src/runtime/asm_amd64.s:2059
}
func ExampleWithMessage() {
cause := errors.New("whoops")
err := errors.WithMessage(cause, "oh noes")
fmt.Println(err)
// Output: oh noes: whoops
}
func ExampleWithStack() {
cause := errors.New("whoops")
err := errors.WithStack(cause)
fmt.Println(err)
// Output: whoops
}
func ExampleWithStack_printf() {
cause := errors.New("whoops")
err := errors.WithStack(cause)
fmt.Printf("%+v", err)
// Example Output:
// whoops
// github.com/pkg/errors_test.ExampleWithStack_printf
// /home/fabstu/go/src/github.com/pkg/errors/example_test.go:55
// testing.runExample
// /usr/lib/go/src/testing/example.go:114
// testing.RunExamples
// /usr/lib/go/src/testing/example.go:38
// testing.(*M).Run
// /usr/lib/go/src/testing/testing.go:744
// main.main
// github.com/pkg/errors/_test/_testmain.go:106
// runtime.main
// /usr/lib/go/src/runtime/proc.go:183
// runtime.goexit
// /usr/lib/go/src/runtime/asm_amd64.s:2086
// github.com/pkg/errors_test.ExampleWithStack_printf
// /home/fabstu/go/src/github.com/pkg/errors/example_test.go:56
// testing.runExample
// /usr/lib/go/src/testing/example.go:114
// testing.RunExamples
// /usr/lib/go/src/testing/example.go:38
// testing.(*M).Run
// /usr/lib/go/src/testing/testing.go:744
// main.main
// github.com/pkg/errors/_test/_testmain.go:106
// runtime.main
// /usr/lib/go/src/runtime/proc.go:183
// runtime.goexit
// /usr/lib/go/src/runtime/asm_amd64.s:2086
}
func ExampleWrap() {
cause := errors.New("whoops")
err := errors.Wrap(cause, "oh noes")

View File

@ -1,6 +1,7 @@
package errors
import (
"errors"
"fmt"
"io"
"regexp"
@ -26,7 +27,7 @@ func TestFormatNew(t *testing.T) {
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatNew\n" +
"\t.+/github.com/pkg/errors/format_test.go:25",
"\t.+/github.com/pkg/errors/format_test.go:26",
}, {
New("error"),
"%q",
@ -56,7 +57,7 @@ func TestFormatErrorf(t *testing.T) {
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatErrorf\n" +
"\t.+/github.com/pkg/errors/format_test.go:55",
"\t.+/github.com/pkg/errors/format_test.go:56",
}}
for i, tt := range tests {
@ -82,7 +83,7 @@ func TestFormatWrap(t *testing.T) {
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:81",
"\t.+/github.com/pkg/errors/format_test.go:82",
}, {
Wrap(io.EOF, "error"),
"%s",
@ -97,14 +98,14 @@ func TestFormatWrap(t *testing.T) {
"EOF\n" +
"error\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:95",
"\t.+/github.com/pkg/errors/format_test.go:96",
}, {
Wrap(Wrap(io.EOF, "error1"), "error2"),
"%+v",
"EOF\n" +
"error1\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:102\n",
"\t.+/github.com/pkg/errors/format_test.go:103\n",
}, {
Wrap(New("error with space"), "context"),
"%q",
@ -135,7 +136,7 @@ func TestFormatWrapf(t *testing.T) {
"EOF\n" +
"error2\n" +
"github.com/pkg/errors.TestFormatWrapf\n" +
"\t.+/github.com/pkg/errors/format_test.go:133",
"\t.+/github.com/pkg/errors/format_test.go:134",
}, {
Wrapf(New("error"), "error%d", 2),
"%s",
@ -149,7 +150,7 @@ func TestFormatWrapf(t *testing.T) {
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatWrapf\n" +
"\t.+/github.com/pkg/errors/format_test.go:148",
"\t.+/github.com/pkg/errors/format_test.go:149",
}}
for i, tt := range tests {
@ -157,16 +158,378 @@ func TestFormatWrapf(t *testing.T) {
}
}
func TestFormatWithStack(t *testing.T) {
tests := []struct {
error
format string
want []string
}{{
WithStack(io.EOF),
"%s",
[]string{"EOF"},
}, {
WithStack(io.EOF),
"%v",
[]string{"EOF"},
}, {
WithStack(io.EOF),
"%+v",
[]string{"EOF",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:175"},
}, {
WithStack(New("error")),
"%s",
[]string{"error"},
}, {
WithStack(New("error")),
"%v",
[]string{"error"},
}, {
WithStack(New("error")),
"%+v",
[]string{"error",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:189",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:189"},
}, {
WithStack(WithStack(io.EOF)),
"%+v",
[]string{"EOF",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:197",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:197"},
}, {
WithStack(WithStack(Wrapf(io.EOF, "message"))),
"%+v",
[]string{"EOF",
"message",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205"},
}, {
WithStack(Errorf("error%d", 1)),
"%+v",
[]string{"error1",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:216",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:216"},
}}
for i, tt := range tests {
testFormatCompleteCompare(t, i, tt.error, tt.format, tt.want, true)
}
}
func TestFormatWithMessage(t *testing.T) {
tests := []struct {
error
format string
want []string
}{{
WithMessage(New("error"), "error2"),
"%s",
[]string{"error2: error"},
}, {
WithMessage(New("error"), "error2"),
"%v",
[]string{"error2: error"},
}, {
WithMessage(New("error"), "error2"),
"%+v",
[]string{
"error",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:244",
"error2"},
}, {
WithMessage(io.EOF, "addition1"),
"%s",
[]string{"addition1: EOF"},
}, {
WithMessage(io.EOF, "addition1"),
"%v",
[]string{"addition1: EOF"},
}, {
WithMessage(io.EOF, "addition1"),
"%+v",
[]string{"EOF", "addition1"},
}, {
WithMessage(WithMessage(io.EOF, "addition1"), "addition2"),
"%v",
[]string{"addition2: addition1: EOF"},
}, {
WithMessage(WithMessage(io.EOF, "addition1"), "addition2"),
"%+v",
[]string{"EOF", "addition1", "addition2"},
}, {
Wrap(WithMessage(io.EOF, "error1"), "error2"),
"%+v",
[]string{"EOF", "error1", "error2",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:272"},
}, {
WithMessage(Errorf("error%d", 1), "error2"),
"%+v",
[]string{"error1",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:278",
"error2"},
}, {
WithMessage(WithStack(io.EOF), "error"),
"%+v",
[]string{
"EOF",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:285",
"error"},
}, {
WithMessage(Wrap(WithStack(io.EOF), "inside-error"), "outside-error"),
"%+v",
[]string{
"EOF",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:293",
"inside-error",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:293",
"outside-error"},
}}
for i, tt := range tests {
testFormatCompleteCompare(t, i, tt.error, tt.format, tt.want, true)
}
}
func TestFormatGeneric(t *testing.T) {
starts := []struct {
err error
want []string
}{
{New("new-error"), []string{
"new-error",
"github.com/pkg/errors.TestFormatGeneric\n" +
"\t.+/github.com/pkg/errors/format_test.go:315"},
}, {Errorf("errorf-error"), []string{
"errorf-error",
"github.com/pkg/errors.TestFormatGeneric\n" +
"\t.+/github.com/pkg/errors/format_test.go:319"},
}, {errors.New("errors-new-error"), []string{
"errors-new-error"},
},
}
wrappers := []wrapper{
{
func(err error) error { return WithMessage(err, "with-message") },
[]string{"with-message"},
}, {
func(err error) error { return WithStack(err) },
[]string{
"github.com/pkg/errors.(func·002|TestFormatGeneric.func2)\n\t" +
".+/github.com/pkg/errors/format_test.go:333",
},
}, {
func(err error) error { return Wrap(err, "wrap-error") },
[]string{
"wrap-error",
"github.com/pkg/errors.(func·003|TestFormatGeneric.func3)\n\t" +
".+/github.com/pkg/errors/format_test.go:339",
},
}, {
func(err error) error { return Wrapf(err, "wrapf-error%d", 1) },
[]string{
"wrapf-error1",
"github.com/pkg/errors.(func·004|TestFormatGeneric.func4)\n\t" +
".+/github.com/pkg/errors/format_test.go:346",
},
},
}
for s := range starts {
err := starts[s].err
want := starts[s].want
testFormatCompleteCompare(t, s, err, "%+v", want, false)
testGenericRecursive(t, err, want, wrappers, 3)
}
}
func testFormatRegexp(t *testing.T, n int, arg interface{}, format, want string) {
got := fmt.Sprintf(format, arg)
lines := strings.SplitN(got, "\n", -1)
for i, w := range strings.SplitN(want, "\n", -1) {
match, err := regexp.MatchString(w, lines[i])
gotLines := strings.SplitN(got, "\n", -1)
wantLines := strings.SplitN(want, "\n", -1)
if len(wantLines) > len(gotLines) {
t.Errorf("test %d: wantLines(%d) > gotLines(%d):\n got: %q\nwant: %q", n+1, len(wantLines), len(gotLines), got, want)
return
}
for i, w := range wantLines {
match, err := regexp.MatchString(w, gotLines[i])
if err != nil {
t.Fatal(err)
}
if !match {
t.Errorf("test %d: line %d: fmt.Sprintf(%q, err): got: %q, want: %q", n+1, i+1, format, got, want)
t.Errorf("test %d: line %d: fmt.Sprintf(%q, err):\n got: %q\nwant: %q", n+1, i+1, format, got, want)
}
}
}
var stackLineR = regexp.MustCompile(`\.`)
// parseBlocks parses input into a slice, where:
// - incase entry contains a newline, its a stacktrace
// - incase entry contains no newline, its a solo line.
//
// Detecting stack boundaries only works incase the WithStack-calls are
// to be found on the same line, thats why it is optionally here.
//
// Example use:
//
// for _, e := range blocks {
// if strings.ContainsAny(e, "\n") {
// // Match as stack
// } else {
// // Match as line
// }
// }
//
func parseBlocks(input string, detectStackboundaries bool) ([]string, error) {
var blocks []string
stack := ""
wasStack := false
lines := map[string]bool{} // already found lines
for _, l := range strings.Split(input, "\n") {
isStackLine := stackLineR.MatchString(l)
switch {
case !isStackLine && wasStack:
blocks = append(blocks, stack, l)
stack = ""
lines = map[string]bool{}
case isStackLine:
if wasStack {
// Detecting two stacks after another, possible cause lines match in
// our tests due to WithStack(WithStack(io.EOF)) on same line.
if detectStackboundaries {
if lines[l] {
if len(stack) == 0 {
return nil, errors.New("len of block must not be zero here")
}
blocks = append(blocks, stack)
stack = l
lines = map[string]bool{l: true}
continue
}
}
stack = stack + "\n" + l
} else {
stack = l
}
lines[l] = true
case !isStackLine && !wasStack:
blocks = append(blocks, l)
default:
return nil, errors.New("must not happen")
}
wasStack = isStackLine
}
// Use up stack
if stack != "" {
blocks = append(blocks, stack)
}
return blocks, nil
}
func testFormatCompleteCompare(t *testing.T, n int, arg interface{}, format string, want []string, detectStackBoundaries bool) {
gotStr := fmt.Sprintf(format, arg)
got, err := parseBlocks(gotStr, detectStackBoundaries)
if err != nil {
t.Fatal(err)
}
if len(got) != len(want) {
t.Fatalf("test %d: fmt.Sprintf(%s, err) -> wrong number of blocks: got(%d) want(%d)\n got: %s\nwant: %s\ngotStr: %q",
n+1, format, len(got), len(want), prettyBlocks(got), prettyBlocks(want), gotStr)
}
for i := range got {
if strings.ContainsAny(want[i], "\n") {
// Match as stack
match, err := regexp.MatchString(want[i], got[i])
if err != nil {
t.Fatal(err)
}
if !match {
t.Fatalf("test %d: block %d: fmt.Sprintf(%q, err):\ngot:\n%q\nwant:\n%q\nall-got:\n%s\nall-want:\n%s\n",
n+1, i+1, format, got[i], want[i], prettyBlocks(got), prettyBlocks(want))
}
} else {
// Match as message
if got[i] != want[i] {
t.Fatalf("test %d: fmt.Sprintf(%s, err) at block %d got != want:\n got: %q\nwant: %q", n+1, format, i+1, got[i], want[i])
}
}
}
}
type wrapper struct {
wrap func(err error) error
want []string
}
func prettyBlocks(blocks []string, prefix ...string) string {
var out []string
for _, b := range blocks {
out = append(out, fmt.Sprintf("%v", b))
}
return " " + strings.Join(out, "\n ")
}
func testGenericRecursive(t *testing.T, beforeErr error, beforeWant []string, list []wrapper, maxDepth int) {
if len(beforeWant) == 0 {
panic("beforeWant must not be empty")
}
for _, w := range list {
if len(w.want) == 0 {
panic("want must not be empty")
}
err := w.wrap(beforeErr)
// Copy required cause append(beforeWant, ..) modified beforeWant subtly.
beforeCopy := make([]string, len(beforeWant))
copy(beforeCopy, beforeWant)
beforeWant := beforeCopy
last := len(beforeWant) - 1
var want []string
// Merge two stacks behind each other.
if strings.ContainsAny(beforeWant[last], "\n") && strings.ContainsAny(w.want[0], "\n") {
want = append(beforeWant[:last], append([]string{beforeWant[last] + "((?s).*)" + w.want[0]}, w.want[1:]...)...)
} else {
want = append(beforeWant, w.want...)
}
testFormatCompleteCompare(t, maxDepth, err, "%+v", want, false)
if maxDepth > 0 {
testGenericRecursive(t, err, want, list, maxDepth-1)
}
}
}

View File

@ -185,7 +185,7 @@ func TestStackTrace(t *testing.T) {
},
}, {
func() error { return New("ooh") }(), []string{
`github.com/pkg/errors.(func·005|TestStackTrace.func1)` +
`github.com/pkg/errors.(func·009|TestStackTrace.func1)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:187", // this is the stack of New
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:187", // this is the stack of New's caller
@ -196,9 +196,9 @@ func TestStackTrace(t *testing.T) {
return Errorf("hello %s", fmt.Sprintf("world"))
}()
}()), []string{
`github.com/pkg/errors.(func·006|TestStackTrace.func2.1)` +
`github.com/pkg/errors.(func·010|TestStackTrace.func2.1)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:196", // this is the stack of Errorf
`github.com/pkg/errors.(func·007|TestStackTrace.func2)` +
`github.com/pkg/errors.(func·011|TestStackTrace.func2)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:197", // this is the stack of Errorf's caller
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:198", // this is the stack of Errorf's caller's caller

View File

@ -14,12 +14,12 @@ Extended attribute support for Go (linux + darwin + freebsd).
const path = "/tmp/myfile"
const prefix = "user."
if err := xattr.Setxattr(path, prefix+"test", []byte("test-attr-value")); err != nil {
if err := xattr.Set(path, prefix+"test", []byte("test-attr-value")); err != nil {
log.Fatal(err)
}
var data []byte
data, err = xattr.Getxattr(path, prefix+"test"); err != nil {
data, err = xattr.Get(path, prefix+"test"); err != nil {
log.Fatal(err)
}
```

View File

@ -1,21 +1,21 @@
/*
Package xattr provides support for extended attributes on linux, darwin and freebsd.
Extended attributes are name:value pairs associated permanently with files and directories,
similar to the environment strings associated with a process.
Extended attributes are name:value pairs associated permanently with files and directories,
similar to the environment strings associated with a process.
An attribute may be defined or undefined. If it is defined, its value may be empty or non-empty.
More details you can find here: https://en.wikipedia.org/wiki/Extended_file_attributes
*/
package xattr
// XAttrError records an error and the operation, file path and attribute that caused it.
type XAttrError struct {
// Error records an error and the operation, file path and attribute that caused it.
type Error struct {
Op string
Path string
Name string
Err error
}
func (e *XAttrError) Error() string {
func (e *Error) Error() string {
return e.Op + " " + e.Path + " " + e.Name + ": " + e.Err.Error()
}

View File

@ -2,32 +2,32 @@
package xattr
// Getxattr retrieves extended attribute data associated with path.
func Getxattr(path, name string) ([]byte, error) {
// Get retrieves extended attribute data associated with path.
func Get(path, name string) ([]byte, error) {
// find size.
size, err := getxattr(path, name, nil, 0, 0, 0)
if err != nil {
return nil, &XAttrError{"getxattr", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
if size > 0 {
buf := make([]byte, size)
// Read into buffer of that size.
read, err := getxattr(path, name, &buf[0], size, 0, 0)
if err != nil {
return nil, &XAttrError{"getxattr", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
return buf[:read], nil
}
return []byte{}, nil
}
// Listxattr retrieves a list of names of extended attributes associated
// List retrieves a list of names of extended attributes associated
// with the given path in the file system.
func Listxattr(path string) ([]string, error) {
func List(path string) ([]string, error) {
// find size.
size, err := listxattr(path, nil, 0, 0)
if err != nil {
return nil, &XAttrError{"listxattr", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
if size > 0 {
@ -35,25 +35,25 @@ func Listxattr(path string) ([]string, error) {
// Read into buffer of that size.
read, err := listxattr(path, &buf[0], size, 0)
if err != nil {
return nil, &XAttrError{"listxattr", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
return nullTermToStrings(buf[:read]), nil
}
return []string{}, nil
}
// Setxattr associates name and data together as an attribute of path.
func Setxattr(path, name string, data []byte) error {
// Set associates name and data together as an attribute of path.
func Set(path, name string, data []byte) error {
if err := setxattr(path, name, &data[0], len(data), 0, 0); err != nil {
return &XAttrError{"setxattr", path, name, err}
return &Error{"xattr.Set", path, name, err}
}
return nil
}
// Removexattr removes the attribute associated with the given path.
func Removexattr(path, name string) error {
// Remove removes the attribute associated with the given path.
func Remove(path, name string) error {
if err := removexattr(path, name, 0); err != nil {
return &XAttrError{"removexattr", path, name, err}
return &Error{"xattr.Remove", path, name, err}
}
return nil
}

View File

@ -10,61 +10,61 @@ const (
EXTATTR_NAMESPACE_USER = 1
)
// Getxattr retrieves extended attribute data associated with path.
func Getxattr(path, name string) ([]byte, error) {
// Get retrieves extended attribute data associated with path.
func Get(path, name string) ([]byte, error) {
// find size.
size, err := extattr_get_file(path, EXTATTR_NAMESPACE_USER, name, nil, 0)
if err != nil {
return nil, &XAttrError{"extattr_get_file", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
if size > 0 {
buf := make([]byte, size)
// Read into buffer of that size.
read, err := extattr_get_file(path, EXTATTR_NAMESPACE_USER, name, &buf[0], size)
if err != nil {
return nil, &XAttrError{"extattr_get_file", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
return buf[:read], nil
}
return []byte{}, nil
}
// Listxattr retrieves a list of names of extended attributes associated
// List retrieves a list of names of extended attributes associated
// with the given path in the file system.
func Listxattr(path string) ([]string, error) {
func List(path string) ([]string, error) {
// find size.
size, err := extattr_list_file(path, EXTATTR_NAMESPACE_USER, nil, 0)
if err != nil {
return nil, &XAttrError{"extattr_list_file", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
if size > 0 {
buf := make([]byte, size)
// Read into buffer of that size.
read, err := extattr_list_file(path, EXTATTR_NAMESPACE_USER, &buf[0], size)
if err != nil {
return nil, &XAttrError{"extattr_list_file", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
return attrListToStrings(buf[:read]), nil
}
return []string{}, nil
}
// Setxattr associates name and data together as an attribute of path.
func Setxattr(path, name string, data []byte) error {
// Set associates name and data together as an attribute of path.
func Set(path, name string, data []byte) error {
written, err := extattr_set_file(path, EXTATTR_NAMESPACE_USER, name, &data[0], len(data))
if err != nil {
return &XAttrError{"extattr_set_file", path, name, err}
return &Error{"xattr.Set", path, name, err}
}
if written != len(data) {
return &XAttrError{"extattr_set_file", path, name, syscall.E2BIG}
return &Error{"xattr.Set", path, name, syscall.E2BIG}
}
return nil
}
// Removexattr removes the attribute associated with the given path.
func Removexattr(path, name string) error {
// Remove removes the attribute associated with the given path.
func Remove(path, name string) error {
if err := extattr_delete_file(path, EXTATTR_NAMESPACE_USER, name); err != nil {
return &XAttrError{"extattr_delete_file", path, name, err}
return &Error{"xattr.Remove", path, name, err}
}
return nil
}

View File

@ -4,58 +4,58 @@ package xattr
import "syscall"
// Getxattr retrieves extended attribute data associated with path.
func Getxattr(path, name string) ([]byte, error) {
// Get retrieves extended attribute data associated with path.
func Get(path, name string) ([]byte, error) {
// find size.
size, err := syscall.Getxattr(path, name, nil)
if err != nil {
return nil, &XAttrError{"getxattr", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
if size > 0 {
data := make([]byte, size)
// Read into buffer of that size.
read, err := syscall.Getxattr(path, name, data)
if err != nil {
return nil, &XAttrError{"getxattr", path, name, err}
return nil, &Error{"xattr.Get", path, name, err}
}
return data[:read], nil
}
return []byte{}, nil
}
// Listxattr retrieves a list of names of extended attributes associated
// List retrieves a list of names of extended attributes associated
// with the given path in the file system.
func Listxattr(path string) ([]string, error) {
func List(path string) ([]string, error) {
// find size.
size, err := syscall.Listxattr(path, nil)
if err != nil {
return nil, &XAttrError{"listxattr", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
if size > 0 {
buf := make([]byte, size)
// Read into buffer of that size.
read, err := syscall.Listxattr(path, buf)
if err != nil {
return nil, &XAttrError{"listxattr", path, "", err}
return nil, &Error{"xattr.List", path, "", err}
}
return nullTermToStrings(buf[:read]), nil
}
return []string{}, nil
}
// Setxattr associates name and data together as an attribute of path.
func Setxattr(path, name string, data []byte) error {
// Set associates name and data together as an attribute of path.
func Set(path, name string, data []byte) error {
if err := syscall.Setxattr(path, name, data, 0); err != nil {
return &XAttrError{"setxattr", path, name, err}
return &Error{"xattr.Set", path, name, err}
}
return nil
}
// Removexattr removes the attribute associated
// Remove removes the attribute associated
// with the given path.
func Removexattr(path, name string) error {
func Remove(path, name string) error {
if err := syscall.Removexattr(path, name); err != nil {
return &XAttrError{"removexattr", path, name, err}
return &Error{"xattr.Remove", path, name, err}
}
return nil
}

View File

@ -18,12 +18,12 @@ func Test_setxattr(t *testing.T) {
}
defer os.Remove(tmp.Name())
err = Setxattr(tmp.Name(), UserPrefix+"test", []byte("test-attr-value"))
err = Set(tmp.Name(), UserPrefix+"test", []byte("test-attr-value"))
if err != nil {
t.Fatal(err)
}
list, err := Listxattr(tmp.Name())
list, err := List(tmp.Name())
if err != nil {
t.Fatal(err)
}
@ -40,7 +40,7 @@ func Test_setxattr(t *testing.T) {
}
var data []byte
data, err = Getxattr(tmp.Name(), UserPrefix+"test")
data, err = Get(tmp.Name(), UserPrefix+"test")
if err != nil {
t.Fatal(err)
}
@ -50,7 +50,7 @@ func Test_setxattr(t *testing.T) {
t.Fail()
}
err = Removexattr(tmp.Name(), UserPrefix+"test")
err = Remove(tmp.Name(), UserPrefix+"test")
if err != nil {
t.Fatal(err)
}

View File

@ -85,8 +85,7 @@ type Chunker struct {
chunkerState
}
// New returns a new Chunker based on polynomial p that reads from rd
// with bufsize and pass all data to hash along the way.
// New returns a new Chunker based on polynomial p that reads from rd.
func New(rd io.Reader, pol Pol) *Chunker {
c := &Chunker{
chunkerState: chunkerState{
@ -141,8 +140,8 @@ func (c *Chunker) reset() {
c.pre = c.MinSize - windowSize
}
// Calculate out_table and mod_table for optimization. Must be called only
// once. This implementation uses a cache in the global variable cache.
// fillTables calculates out_table and mod_table for optimization. This
// implementation uses a cache in the global variable cache.
func (c *Chunker) fillTables() {
// if polynomial hasn't been specified, do not compute anything for now
if c.pol == 0 {

View File

@ -8,6 +8,7 @@ Many of the most widely used Go projects are built using Cobra including:
* [Hugo](http://gohugo.io)
* [rkt](https://github.com/coreos/rkt)
* [etcd](https://github.com/coreos/etcd)
* [Docker](https://github.com/docker/docker)
* [Docker (distribution)](https://github.com/docker/distribution)
* [OpenShift](https://www.openshift.com/)
* [Delve](https://github.com/derekparker/delve)
@ -157,12 +158,17 @@ In a Cobra app, typically the main.go file is very bare. It serves, one purpose,
```go
package main
import "{pathToYourApp}/cmd"
import (
"fmt"
"os"
"{pathToYourApp}/cmd"
)
func main() {
if err := cmd.RootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(-1)
os.Exit(1)
}
}
```
@ -313,12 +319,17 @@ In a Cobra app, typically the main.go file is very bare. It serves, one purpose,
```go
package main
import "{pathToYourApp}/cmd"
import (
"fmt"
"os"
"{pathToYourApp}/cmd"
)
func main() {
if err := cmd.RootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(-1)
os.Exit(1)
}
}
```
@ -337,6 +348,7 @@ package cmd
import (
"github.com/spf13/cobra"
"fmt"
)
func init() {
@ -744,7 +756,7 @@ providing a way to handle the errors in one location. The current list of functi
* PersistentPostRunE
If you would like to silence the default `error` and `usage` output in favor of your own, you can set `SilenceUsage`
and `SilenceErrors` to `false` on the command. A child command respects these flags if they are set on the parent
and `SilenceErrors` to `true` on the command. A child command respects these flags if they are set on the parent
command.
**Example Usage using RunE:**

View File

@ -10,6 +10,7 @@ import (
"github.com/spf13/pflag"
)
// Annotations for Bash completion.
const (
BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"
BashCompCustom = "cobra_annotation_bash_completion_custom"
@ -22,7 +23,7 @@ func preamble(out io.Writer, name string) error {
if err != nil {
return err
}
_, err = fmt.Fprint(out, `
preamStr := `
__debug()
{
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
@ -87,8 +88,8 @@ __handle_reply()
local index flag
flag="${cur%%=*}"
__index_of_word "${flag}" "${flags_with_completion[@]}"
COMPREPLY=()
if [[ ${index} -ge 0 ]]; then
COMPREPLY=()
PREFIX=""
cur="${cur#*=}"
${flags_completion[${index}]}
@ -224,7 +225,7 @@ __handle_command()
fi
c=$((c+1))
__debug "${FUNCNAME[0]}: looking for ${next_command}"
declare -F $next_command >/dev/null && $next_command
declare -F "$next_command" >/dev/null && $next_command
}
__handle_word()
@ -246,7 +247,8 @@ __handle_word()
__handle_word
}
`)
`
_, err = fmt.Fprint(out, preamStr)
return err
}
@ -566,6 +568,7 @@ func gen(cmd *Command, w io.Writer) error {
return nil
}
// GenBashCompletion generates bash completion file and writes to the passed writer.
func (cmd *Command) GenBashCompletion(w io.Writer) error {
if err := preamble(w, cmd.Name()); err != nil {
return err
@ -585,6 +588,7 @@ func nonCompletableFlag(flag *pflag.Flag) bool {
return flag.Hidden || len(flag.Deprecated) > 0
}
// GenBashCompletionFile generates bash completion file.
func (cmd *Command) GenBashCompletionFile(filename string) error {
outFile, err := os.Create(filename)
if err != nil {

View File

@ -18,7 +18,7 @@ func main() {
}
```
That will get you completions of subcommands and flags. If you make additional annotations to your code, you can get even more intelligent and flexible behavior.
`out.sh` will get you completions of subcommands and flags. Copy it to `/etc/bash_completion.d/` as described [here](https://debian-administration.org/article/316/An_introduction_to_bash_completion_part_1) and reset your terminal to use autocompletion. If you make additional annotations to your code, you can get even more intelligent and flexible behavior.
## Creating your own custom functions

View File

@ -37,7 +37,8 @@ var templateFuncs = template.FuncMap{
var initializers []func()
// Automatic prefix matching can be a dangerous thing to automatically enable in CLI tools.
// EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing
// to automatically enable in CLI tools.
// Set this to true to enable it.
var EnablePrefixMatching = false
@ -51,7 +52,7 @@ func AddTemplateFunc(name string, tmplFunc interface{}) {
templateFuncs[name] = tmplFunc
}
// AddTemplateFuncs adds multiple template functions availalble to Usage and
// AddTemplateFuncs adds multiple template functions that are available to Usage and
// Help template generation.
func AddTemplateFuncs(tmplFuncs template.FuncMap) {
for k, v := range tmplFuncs {

View File

@ -612,7 +612,7 @@ func TestSubcommandExecuteC(t *testing.T) {
Use: "echo message",
Run: func(c *Command, args []string) {
msg := strings.Join(args, " ")
c.Println(msg, msg)
c.Println(msg)
},
}

View File

@ -57,6 +57,9 @@ type Command struct {
Deprecated string
// Is this command hidden and should NOT show up in the list of available commands?
Hidden bool
// Annotations are key/value pairs that can be used by applications to identify or
// group commands
Annotations map[string]string
// Full set of flags
flags *flag.FlagSet
// Set of flags childrens of this command will inherit
@ -109,10 +112,11 @@ type Command struct {
flagErrorBuf *bytes.Buffer
args []string // actual args parsed from flags
output *io.Writer // out writer if set in SetOutput(w)
usageFunc func(*Command) error // Usage can be defined by application
usageTemplate string // Can be defined by Application
args []string // actual args parsed from flags
output io.Writer // out writer if set in SetOutput(w)
usageFunc func(*Command) error // Usage can be defined by application
usageTemplate string // Can be defined by Application
flagErrorFunc func(*Command, error) error
helpTemplate string // Can be defined by Application
helpFunc func(*Command, []string) // Help can be defined by application
helpCommand *Command // The help command
@ -128,7 +132,7 @@ type Command struct {
DisableFlagParsing bool
}
// os.Args[1:] by default, if desired, can be overridden
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
// particularly useful when testing.
func (c *Command) SetArgs(a []string) {
c.args = a
@ -137,29 +141,36 @@ func (c *Command) SetArgs(a []string) {
// SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used.
func (c *Command) SetOutput(output io.Writer) {
c.output = &output
c.output = output
}
// Usage can be defined by application.
// SetUsageFunc sets usage function. Usage can be defined by application.
func (c *Command) SetUsageFunc(f func(*Command) error) {
c.usageFunc = f
}
// Can be defined by Application.
// SetUsageTemplate sets usage template. Can be defined by Application.
func (c *Command) SetUsageTemplate(s string) {
c.usageTemplate = s
}
// Can be defined by Application.
// SetFlagErrorFunc sets a function to generate an error when flag parsing
// fails.
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
c.flagErrorFunc = f
}
// SetHelpFunc sets help function. Can be defined by Application.
func (c *Command) SetHelpFunc(f func(*Command, []string)) {
c.helpFunc = f
}
// SetHelpCommand sets help command.
func (c *Command) SetHelpCommand(cmd *Command) {
c.helpCommand = cmd
}
// Can be defined by Application.
// SetHelpTemplate sets help template to be used. Application can use it to set custom template.
func (c *Command) SetHelpTemplate(s string) {
c.helpTemplate = s
}
@ -176,17 +187,19 @@ func (c *Command) SetGlobalNormalizationFunc(n func(f *flag.FlagSet, name string
}
}
// OutOrStdout returns output to stdout.
func (c *Command) OutOrStdout() io.Writer {
return c.getOut(os.Stdout)
}
// OutOrStderr returns output to stderr
func (c *Command) OutOrStderr() io.Writer {
return c.getOut(os.Stderr)
}
func (c *Command) getOut(def io.Writer) io.Writer {
if c.output != nil {
return *c.output
return c.output
}
if c.HasParent() {
return c.parent.getOut(def)
@ -224,12 +237,8 @@ func (c *Command) Usage() error {
// HelpFunc returns either the function set by SetHelpFunc for this command
// or a parent, or it returns a function with default help behavior.
func (c *Command) HelpFunc() func(*Command, []string) {
cmd := c
for cmd != nil {
if cmd.helpFunc != nil {
return cmd.helpFunc
}
cmd = cmd.parent
if helpFunc := c.checkHelpFunc(); helpFunc != nil {
return helpFunc
}
return func(*Command, []string) {
c.mergePersistentFlags()
@ -240,6 +249,20 @@ func (c *Command) HelpFunc() func(*Command, []string) {
}
}
// checkHelpFunc checks if there is helpFunc in ancestors of c.
func (c *Command) checkHelpFunc() func(*Command, []string) {
if c == nil {
return nil
}
if c.helpFunc != nil {
return c.helpFunc
}
if c.HasParent() {
return c.parent.checkHelpFunc()
}
return nil
}
// Help puts out the help for the command.
// Used when a user calls help [command].
// Can be defined by user by overriding HelpFunc.
@ -248,6 +271,7 @@ func (c *Command) Help() error {
return nil
}
// UsageString return usage string.
func (c *Command) UsageString() string {
tmpOutput := c.output
bb := new(bytes.Buffer)
@ -257,8 +281,25 @@ func (c *Command) UsageString() string {
return bb.String()
}
// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this
// command or a parent, or it returns a function which returns the original
// error.
func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
if c.flagErrorFunc != nil {
return c.flagErrorFunc
}
if c.HasParent() {
return c.parent.FlagErrorFunc()
}
return func(c *Command, err error) error {
return err
}
}
var minUsagePadding = 25
// UsagePadding return padding for the usage.
func (c *Command) UsagePadding() int {
if c.parent == nil || minUsagePadding > c.parent.commandsMaxUseLen {
return minUsagePadding
@ -268,7 +309,7 @@ func (c *Command) UsagePadding() int {
var minCommandPathPadding = 11
//
// CommandPathPadding return padding for the command path.
func (c *Command) CommandPathPadding() int {
if c.parent == nil || minCommandPathPadding > c.parent.commandsMaxCommandPathLen {
return minCommandPathPadding
@ -278,6 +319,7 @@ func (c *Command) CommandPathPadding() int {
var minNamePadding = 11
// NamePadding returns padding for the name.
func (c *Command) NamePadding() int {
if c.parent == nil || minNamePadding > c.parent.commandsMaxNameLen {
return minNamePadding
@ -285,6 +327,7 @@ func (c *Command) NamePadding() int {
return c.parent.commandsMaxNameLen
}
// UsageTemplate returns usage template for the command.
func (c *Command) UsageTemplate() string {
if c.usageTemplate != "" {
return c.usageTemplate
@ -298,28 +341,28 @@ func (c *Command) UsageTemplate() string {
{{ .CommandPath}} [command]{{end}}{{if gt .Aliases 0}}
Aliases:
{{.NameAndAliases}}
{{end}}{{if .HasExample}}
{{.NameAndAliases}}{{end}}{{if .HasExample}}
Examples:
{{ .Example }}{{end}}{{ if .HasAvailableSubCommands}}
{{ .Example }}{{end}}{{if .HasAvailableSubCommands}}
Available Commands:{{range .Commands}}{{if .IsAvailableCommand}}
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{ if .HasAvailableLocalFlags}}
Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "help"))}}
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}
Flags:
{{.LocalFlags.FlagUsages | trimRightSpace}}{{end}}{{ if .HasAvailableInheritedFlags}}
{{.LocalFlags.FlagUsages | trimRightSpace}}{{end}}{{if .HasAvailableInheritedFlags}}
Global Flags:
{{.InheritedFlags.FlagUsages | trimRightSpace}}{{end}}{{if .HasHelpSubCommands}}
Additional help topics:{{range .Commands}}{{if .IsHelpCommand}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{ if .HasAvailableSubCommands }}
Additional help topics:{{range .Commands}}{{if .IsAdditionalHelpTopicCommand}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableSubCommands}}
Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}}
`
}
// HelpTemplate return help template for the command.
func (c *Command) HelpTemplate() string {
if c.helpTemplate != "" {
return c.helpTemplate
@ -340,20 +383,18 @@ func (c *Command) resetChildrensParents() {
}
}
// Test if the named flag is a boolean flag.
func isBooleanFlag(name string, f *flag.FlagSet) bool {
func hasNoOptDefVal(name string, f *flag.FlagSet) bool {
flag := f.Lookup(name)
if flag == nil {
return false
}
return flag.Value.Type() == "bool"
return len(flag.NoOptDefVal) > 0
}
// Test if the named flag is a boolean flag.
func isBooleanShortFlag(name string, f *flag.FlagSet) bool {
func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool {
result := false
f.VisitAll(func(f *flag.Flag) {
if f.Shorthand == name && f.Value.Type() == "bool" {
fs.VisitAll(func(flag *flag.Flag) {
if flag.Shorthand == name && len(flag.NoOptDefVal) > 0 {
result = true
}
})
@ -379,13 +420,13 @@ func stripFlags(args []string, c *Command) []string {
inQuote = true
case strings.HasPrefix(y, "--") && !strings.Contains(y, "="):
// TODO: this isn't quite right, we should really check ahead for 'true' or 'false'
inFlag = !isBooleanFlag(y[2:], c.Flags())
case strings.HasPrefix(y, "-") && !strings.Contains(y, "=") && len(y) == 2 && !isBooleanShortFlag(y[1:], c.Flags()):
inFlag = !hasNoOptDefVal(y[2:], c.Flags())
case strings.HasPrefix(y, "-") && !strings.Contains(y, "=") && len(y) == 2 && !shortHasNoOptDefVal(y[1:], c.Flags()):
inFlag = true
case inFlag:
inFlag = false
case y == "":
// strip empty commands, as the go tests expect this to be ok....
// strip empty commands, as the go tests expect this to be ok....
case !strings.HasPrefix(y, "-"):
commands = append(commands, y)
inFlag = false
@ -414,7 +455,7 @@ func argsMinusFirstX(args []string, x string) []string {
return args
}
// find the target command given the args and command tree
// Find the target command given the args and command tree
// Meant to be run on the highest node. Only searches down.
func (c *Command) Find(args []string) (*Command, []string, error) {
if c == nil {
@ -482,6 +523,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
return commandFound, a, nil
}
// SuggestionsFor provides suggestions for the typedName.
func (c *Command) SuggestionsFor(typedName string) []string {
suggestions := []string{}
for _, cmd := range c.commands {
@ -502,6 +544,7 @@ func (c *Command) SuggestionsFor(typedName string) []string {
return suggestions
}
// VisitParents visits all parents of the command and invokes fn on each parent.
func (c *Command) VisitParents(fn func(*Command)) {
var traverse func(*Command) *Command
@ -517,6 +560,7 @@ func (c *Command) VisitParents(fn func(*Command)) {
traverse(c)
}
// Root finds root command.
func (c *Command) Root() *Command {
var findRoot func(*Command) *Command
@ -553,7 +597,7 @@ func (c *Command) execute(a []string) (err error) {
err = c.ParseFlags(a)
if err != nil {
return err
return c.FlagErrorFunc()(c, err)
}
// If help is called, regardless of other flags, return we want help
// Also say we need help if the command isn't runnable.
@ -641,7 +685,7 @@ func (c *Command) errorMsgFromParse() string {
return ""
}
// Call execute to use the args (os.Args[1:] by default)
// Execute Call execute to use the args (os.Args[1:] by default)
// and run through the command tree finding appropriate matches
// for commands and then corresponding flags.
func (c *Command) Execute() error {
@ -649,8 +693,8 @@ func (c *Command) Execute() error {
return err
}
// ExecuteC executes the command.
func (c *Command) ExecuteC() (cmd *Command, err error) {
// Regardless of what command execute is called on, run on Root only
if c.HasParent() {
return c.Root().ExecuteC()
@ -712,6 +756,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
}
func (c *Command) initHelpFlag() {
c.mergePersistentFlags()
if c.Flags().Lookup("help") == nil {
c.Flags().BoolP("help", "h", false, "help for "+c.Name())
}
@ -734,7 +779,7 @@ func (c *Command) initHelpCmd() {
Run: func(c *Command, args []string) {
cmd, _, e := c.Root().Find(args)
if cmd == nil || e != nil {
c.Printf("Unknown help topic %#q.", args)
c.Printf("Unknown help topic %#q\n", args)
c.Root().Usage()
} else {
cmd.Help()
@ -742,10 +787,11 @@ func (c *Command) initHelpCmd() {
},
}
}
c.RemoveCommand(c.helpCommand)
c.AddCommand(c.helpCommand)
}
// Used for testing.
// ResetCommands used for testing.
func (c *Command) ResetCommands() {
c.commands = nil
c.helpCommand = nil
@ -868,7 +914,7 @@ func (c *Command) UseLine() string {
return str + c.Use
}
// For use in determining which flags have been assigned to which commands
// DebugFlags used to determine which flags have been assigned to which commands
// and which persist.
func (c *Command) DebugFlags() {
c.Println("DebugFlags called on", c.Name())
@ -923,7 +969,8 @@ func (c *Command) Name() string {
if i >= 0 {
name = name[:i]
}
return name
c.name = name
return c.name
}
// HasAlias determines if a given string is an alias of the command.
@ -936,10 +983,12 @@ func (c *Command) HasAlias(s string) bool {
return false
}
// NameAndAliases returns string containing name and all aliases
func (c *Command) NameAndAliases() string {
return strings.Join(append([]string{c.Name()}, c.Aliases...), ", ")
}
// HasExample determines if the command has example.
func (c *Command) HasExample() bool {
return len(c.Example) > 0
}
@ -972,11 +1021,12 @@ func (c *Command) IsAvailableCommand() bool {
return false
}
// IsHelpCommand determines if a command is a 'help' command; a help command is
// determined by the fact that it is NOT runnable/hidden/deprecated, and has no
// sub commands that are runnable/hidden/deprecated.
func (c *Command) IsHelpCommand() bool {
// IsAdditionalHelpTopicCommand determines if a command is an additional
// help topic command; additional help topic command is determined by the
// fact that it is NOT runnable/hidden/deprecated, and has no sub commands that
// are runnable/hidden/deprecated.
// Concrete example: https://github.com/spf13/cobra/issues/393#issuecomment-282741924.
func (c *Command) IsAdditionalHelpTopicCommand() bool {
// if a command is runnable, deprecated, or hidden it is not a 'help' command
if c.Runnable() || len(c.Deprecated) != 0 || c.Hidden {
return false
@ -984,7 +1034,7 @@ func (c *Command) IsHelpCommand() bool {
// if any non-help sub commands are found, the command is not a 'help' command
for _, sub := range c.commands {
if !sub.IsHelpCommand() {
if !sub.IsAdditionalHelpTopicCommand() {
return false
}
}
@ -997,10 +1047,9 @@ func (c *Command) IsHelpCommand() bool {
// that need to be shown in the usage/help default template under 'additional help
// topics'.
func (c *Command) HasHelpSubCommands() bool {
// return true on the first found available 'help' sub command
for _, sub := range c.commands {
if sub.IsHelpCommand() {
if sub.IsAdditionalHelpTopicCommand() {
return true
}
}
@ -1012,7 +1061,6 @@ func (c *Command) HasHelpSubCommands() bool {
// HasAvailableSubCommands determines if a command has available sub commands that
// need to be shown in the usage/help default template under 'available commands'.
func (c *Command) HasAvailableSubCommands() bool {
// return true on the first found available (non deprecated/help/hidden)
// sub command
for _, sub := range c.commands {
@ -1036,7 +1084,7 @@ func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) f
return c.globNormFunc
}
// Flage returns the complete FlagSet that applies
// Flags returns the complete FlagSet that applies
// to this command (local and persistent declared here and by all parents).
func (c *Command) Flags() *flag.FlagSet {
if c.flags == nil {
@ -1136,44 +1184,44 @@ func (c *Command) ResetFlags() {
c.pflags.SetOutput(c.flagErrorBuf)
}
// Does the command contain any flags (local plus persistent from the entire structure).
// HasFlags checks if the command contains any flags (local plus persistent from the entire structure).
func (c *Command) HasFlags() bool {
return c.Flags().HasFlags()
}
// Does the command contain persistent flags.
// HasPersistentFlags checks if the command contains persistent flags.
func (c *Command) HasPersistentFlags() bool {
return c.PersistentFlags().HasFlags()
}
// Does the command has flags specifically declared locally.
// HasLocalFlags checks if the command has flags specifically declared locally.
func (c *Command) HasLocalFlags() bool {
return c.LocalFlags().HasFlags()
}
// Does the command have flags inherited from its parent command.
// HasInheritedFlags checks if the command has flags inherited from its parent command.
func (c *Command) HasInheritedFlags() bool {
return c.InheritedFlags().HasFlags()
}
// Does the command contain any flags (local plus persistent from the entire
// HasAvailableFlags checks if the command contains any flags (local plus persistent from the entire
// structure) which are not hidden or deprecated.
func (c *Command) HasAvailableFlags() bool {
return c.Flags().HasAvailableFlags()
}
// Does the command contain persistent flags which are not hidden or deprecated.
// HasAvailablePersistentFlags checks if the command contains persistent flags which are not hidden or deprecated.
func (c *Command) HasAvailablePersistentFlags() bool {
return c.PersistentFlags().HasAvailableFlags()
}
// Does the command has flags specifically declared locally which are not hidden
// HasAvailableLocalFlags checks if the command has flags specifically declared locally which are not hidden
// or deprecated.
func (c *Command) HasAvailableLocalFlags() bool {
return c.LocalFlags().HasAvailableFlags()
}
// Does the command have flags inherited from its parent command which are
// HasAvailableInheritedFlags checks if the command has flags inherited from its parent command which are
// not hidden or deprecated.
func (c *Command) HasAvailableInheritedFlags() bool {
return c.InheritedFlags().HasAvailableFlags()

View File

@ -1,6 +1,8 @@
package cobra
import (
"bytes"
"fmt"
"os"
"reflect"
"testing"
@ -134,6 +136,20 @@ func Test_DisableFlagParsing(t *testing.T) {
}
}
func TestInitHelpFlagMergesFlags(t *testing.T) {
usage := "custom flag"
baseCmd := Command{Use: "testcmd"}
baseCmd.PersistentFlags().Bool("help", false, usage)
cmd := Command{Use: "do"}
baseCmd.AddCommand(&cmd)
cmd.initHelpFlag()
actual := cmd.Flags().Lookup("help").Usage
if actual != usage {
t.Fatalf("Expected the help flag from the base command with usage '%s', but got the default with usage '%s'", usage, actual)
}
}
func TestCommandsAreSorted(t *testing.T) {
EnableCommandSorting = true
@ -174,3 +190,35 @@ func TestEnableCommandSortingIsDisabled(t *testing.T) {
EnableCommandSorting = true
}
func TestSetOutput(t *testing.T) {
cmd := &Command{}
cmd.SetOutput(nil)
if out := cmd.OutOrStdout(); out != os.Stdout {
t.Fatalf("expected setting output to nil to revert back to stdout, got %v", out)
}
}
func TestFlagErrorFunc(t *testing.T) {
cmd := &Command{
Use: "print",
RunE: func(cmd *Command, args []string) error {
return nil
},
}
expectedFmt := "This is expected: %s"
cmd.SetFlagErrorFunc(func(c *Command, err error) error {
return fmt.Errorf(expectedFmt, err)
})
cmd.SetArgs([]string{"--bogus-flag"})
cmd.SetOutput(new(bytes.Buffer))
err := cmd.Execute()
expected := fmt.Sprintf(expectedFmt, "unknown flag: --bogus-flag")
if err.Error() != expected {
t.Errorf("expected %v, got %v", expected, err.Error())
}
}

View File

@ -37,7 +37,7 @@ func GenManTree(cmd *cobra.Command, header *GenManHeader, dir string) error {
return GenManTreeFromOpts(cmd, GenManTreeOptions{
Header: header,
Path: dir,
CommandSeparator: "_",
CommandSeparator: "-",
})
}
@ -49,7 +49,7 @@ func GenManTreeFromOpts(cmd *cobra.Command, opts GenManTreeOptions) error {
header = &GenManHeader{}
}
for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c.IsHelpCommand() {
if !c.IsAvailableCommand() || c.IsAdditionalHelpTopicCommand() {
continue
}
if err := GenManTreeFromOpts(c, opts); err != nil {
@ -216,7 +216,7 @@ func genMan(cmd *cobra.Command, header *GenManHeader) []byte {
children := cmd.Commands()
sort.Sort(byName(children))
for _, c := range children {
if !c.IsAvailableCommand() || c.IsHelpCommand() {
if !c.IsAvailableCommand() || c.IsAdditionalHelpTopicCommand() {
continue
}
seealso := fmt.Sprintf("**%s-%s(%s)**", dashCommandName, c.Name(), header.Section)

View File

@ -15,7 +15,7 @@ func main() {
Use: "test",
Short: "my test program",
}
header := &cobra.GenManHeader{
header := &doc.GenManHeader{
Title: "MINE",
Section: "3",
}
@ -23,4 +23,4 @@ func main() {
}
```
That will get you a man page `/tmp/test.1`
That will get you a man page `/tmp/test.3`

View File

@ -8,7 +8,7 @@ import (
"github.com/spf13/cobra/doc"
)
func ExampleCommand_GenManTree() {
func ExampleGenManTree() {
cmd := &cobra.Command{
Use: "test",
Short: "my test program",
@ -20,7 +20,7 @@ func ExampleCommand_GenManTree() {
doc.GenManTree(cmd, header, "/tmp")
}
func ExampleCommand_GenMan() {
func ExampleGenMan() {
cmd := &cobra.Command{
Use: "test",
Short: "my test program",

View File

@ -119,7 +119,7 @@ func GenMarkdownCustom(cmd *cobra.Command, w io.Writer, linkHandler func(string)
sort.Sort(byName(children))
for _, child := range children {
if !child.IsAvailableCommand() || child.IsHelpCommand() {
if !child.IsAvailableCommand() || child.IsAdditionalHelpTopicCommand() {
continue
}
cname := name + " " + child.Name()
@ -149,7 +149,7 @@ func GenMarkdownTree(cmd *cobra.Command, dir string) error {
func GenMarkdownTreeCustom(cmd *cobra.Command, dir string, filePrepender, linkHandler func(string) string) error {
for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c.IsHelpCommand() {
if !c.IsAvailableCommand() || c.IsAdditionalHelpTopicCommand() {
continue
}
if err := GenMarkdownTreeCustom(c, dir, filePrepender, linkHandler); err != nil {

View File

@ -32,15 +32,15 @@ import (
"io/ioutil"
"os"
kubectlcmd "k8s.io/kubernetes/pkg/kubectl/cmd"
"k8s.io/kubernetes/pkg/kubectl/cmd"
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
"github.com/spf13/cobra/doc"
)
func main() {
cmd := kubectlcmd.NewKubectlCommand(cmdutil.NewFactory(nil), os.Stdin, ioutil.Discard, ioutil.Discard)
doc.GenMarkdownTree(cmd, "./")
kubectl := cmd.NewKubectlCommand(cmdutil.NewFactory(nil), os.Stdin, ioutil.Discard, ioutil.Discard)
doc.GenMarkdownTree(kubectl, "./")
}
```
@ -101,4 +101,3 @@ linkHandler := func(name string) string {
return "/commands/" + strings.ToLower(base) + "/"
}
```

View File

@ -13,7 +13,11 @@
package doc
import "github.com/spf13/cobra"
import (
"strings"
"github.com/spf13/cobra"
)
// Test to see if we have a reason to print See Also information in docs
// Basically this is a test for a parent commend or a subcommand which is
@ -23,7 +27,7 @@ func hasSeeAlso(cmd *cobra.Command) bool {
return true
}
for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c.IsHelpCommand() {
if !c.IsAvailableCommand() || c.IsAdditionalHelpTopicCommand() {
continue
}
return true
@ -31,6 +35,15 @@ func hasSeeAlso(cmd *cobra.Command) bool {
return false
}
// Temporary workaround for yaml lib generating incorrect yaml with long strings
// that do not contain \n.
func forceMultiLine(s string) string {
if len(s) > 60 && !strings.Contains(s, "\n") {
s = s + "\n"
}
return s
}
type byName []*cobra.Command
func (s byName) Len() int { return len(s) }

View File

@ -0,0 +1,165 @@
// Copyright 2016 French Ben. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package doc
import (
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v2"
)
type cmdOption struct {
Name string
Shorthand string `yaml:",omitempty"`
DefaultValue string `yaml:"default_value,omitempty"`
Usage string `yaml:",omitempty"`
}
type cmdDoc struct {
Name string
Synopsis string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Options []cmdOption `yaml:",omitempty"`
InheritedOptions []cmdOption `yaml:"inherited_options,omitempty"`
Example string `yaml:",omitempty"`
SeeAlso []string `yaml:"see_also,omitempty"`
}
// GenYamlTree creates yaml structured ref files for this command and all descendants
// in the directory given. This function may not work
// correctly if your command names have - in them. If you have `cmd` with two
// subcmds, `sub` and `sub-third`. And `sub` has a subcommand called `third`
// it is undefined which help output will be in the file `cmd-sub-third.1`.
func GenYamlTree(cmd *cobra.Command, dir string) error {
identity := func(s string) string { return s }
emptyStr := func(s string) string { return "" }
return GenYamlTreeCustom(cmd, dir, emptyStr, identity)
}
// GenYamlTreeCustom creates yaml structured ref files
func GenYamlTreeCustom(cmd *cobra.Command, dir string, filePrepender, linkHandler func(string) string) error {
for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c.IsAdditionalHelpTopicCommand() {
continue
}
if err := GenYamlTreeCustom(c, dir, filePrepender, linkHandler); err != nil {
return err
}
}
basename := strings.Replace(cmd.CommandPath(), " ", "_", -1) + ".yaml"
filename := filepath.Join(dir, basename)
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
if _, err := io.WriteString(f, filePrepender(filename)); err != nil {
return err
}
if err := GenYamlCustom(cmd, f, linkHandler); err != nil {
return err
}
return nil
}
// GenYaml creates yaml output
func GenYaml(cmd *cobra.Command, w io.Writer) error {
return GenYamlCustom(cmd, w, func(s string) string { return s })
}
// GenYamlCustom creates custom yaml output
func GenYamlCustom(cmd *cobra.Command, w io.Writer, linkHandler func(string) string) error {
yamlDoc := cmdDoc{}
yamlDoc.Name = cmd.CommandPath()
yamlDoc.Synopsis = forceMultiLine(cmd.Short)
yamlDoc.Description = forceMultiLine(cmd.Long)
if len(cmd.Example) > 0 {
yamlDoc.Example = cmd.Example
}
flags := cmd.NonInheritedFlags()
if flags.HasFlags() {
yamlDoc.Options = genFlagResult(flags)
}
flags = cmd.InheritedFlags()
if flags.HasFlags() {
yamlDoc.InheritedOptions = genFlagResult(flags)
}
if hasSeeAlso(cmd) {
result := []string{}
if cmd.HasParent() {
parent := cmd.Parent()
result = append(result, parent.CommandPath()+" - "+parent.Short)
}
children := cmd.Commands()
sort.Sort(byName(children))
for _, child := range children {
if !child.IsAvailableCommand() || child.IsAdditionalHelpTopicCommand() {
continue
}
result = append(result, child.Name()+" - "+child.Short)
}
yamlDoc.SeeAlso = result
}
final, err := yaml.Marshal(&yamlDoc)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
if _, err := fmt.Fprintf(w, string(final)); err != nil {
return err
}
return nil
}
func genFlagResult(flags *pflag.FlagSet) []cmdOption {
var result []cmdOption
flags.VisitAll(func(flag *pflag.Flag) {
// Todo, when we mark a shorthand is deprecated, but specify an empty message.
// The flag.ShorthandDeprecated is empty as the shorthand is deprecated.
// Using len(flag.ShorthandDeprecated) > 0 can't handle this, others are ok.
if !(len(flag.ShorthandDeprecated) > 0) && len(flag.Shorthand) > 0 {
opt := cmdOption{
flag.Name,
flag.Shorthand,
flag.DefValue,
forceMultiLine(flag.Usage),
}
result = append(result, opt)
} else {
opt := cmdOption{
Name: flag.Name,
DefaultValue: forceMultiLine(flag.DefValue),
Usage: forceMultiLine(flag.Usage),
}
result = append(result, opt)
}
})
return result
}

View File

@ -0,0 +1,103 @@
# Generating Yaml Docs For Your Own cobra.Command
Generating yaml files from a cobra command is incredibly easy. An example is as follows:
```go
package main
import (
"github.com/spf13/cobra"
"github.com/spf13/cobra/doc"
)
func main() {
cmd := &cobra.Command{
Use: "test",
Short: "my test program",
}
doc.GenYamlTree(cmd, "/tmp")
}
```
That will get you a Yaml document `/tmp/test.yaml`
## Generate yaml docs for the entire command tree
This program can actually generate docs for the kubectl command in the kubernetes project
```go
package main
import (
"io/ioutil"
"os"
"k8s.io/kubernetes/pkg/kubectl/cmd"
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
"github.com/spf13/cobra/doc"
)
func main() {
kubectl := cmd.NewKubectlCommand(cmdutil.NewFactory(nil), os.Stdin, ioutil.Discard, ioutil.Discard)
doc.GenYamlTree(kubectl, "./")
}
```
This will generate a whole series of files, one for each command in the tree, in the directory specified (in this case "./")
## Generate yaml docs for a single command
You may wish to have more control over the output, or only generate for a single command, instead of the entire command tree. If this is the case you may prefer to `GenYaml` instead of `GenYamlTree`
```go
out := new(bytes.Buffer)
doc.GenYaml(cmd, out)
```
This will write the yaml doc for ONLY "cmd" into the out, buffer.
## Customize the output
Both `GenYaml` and `GenYamlTree` have alternate versions with callbacks to get some control of the output:
```go
func GenYamlTreeCustom(cmd *Command, dir string, filePrepender, linkHandler func(string) string) error {
//...
}
```
```go
func GenYamlCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) error {
//...
}
```
The `filePrepender` will prepend the return value given the full filepath to the rendered Yaml file. A common use case is to add front matter to use the generated documentation with [Hugo](http://gohugo.io/):
```go
const fmTemplate = `---
date: %s
title: "%s"
slug: %s
url: %s
---
`
filePrepender := func(filename string) string {
now := time.Now().Format(time.RFC3339)
name := filepath.Base(filename)
base := strings.TrimSuffix(name, path.Ext(name))
url := "/commands/" + strings.ToLower(base) + "/"
return fmt.Sprintf(fmTemplate, now, strings.Replace(base, "_", " ", -1), base, url)
}
```
The `linkHandler` can be used to customize the rendered internal links to the commands, given a filename:
```go
linkHandler := func(name string) string {
base := strings.TrimSuffix(name, path.Ext(name))
return "/commands/" + strings.ToLower(base) + "/"
}
```

View File

@ -0,0 +1,88 @@
package doc
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
var _ = fmt.Println
var _ = os.Stderr
func TestGenYamlDoc(t *testing.T) {
c := initializeWithRootCmd()
// Need two commands to run the command alphabetical sort
cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated)
c.AddCommand(cmdPrint, cmdEcho)
cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp)
out := new(bytes.Buffer)
// We generate on s subcommand so we have both subcommands and parents
if err := GenYaml(cmdEcho, out); err != nil {
t.Fatal(err)
}
found := out.String()
// Our description
expected := cmdEcho.Long
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// Better have our example
expected = cmdEcho.Example
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// A local flag
expected = "boolone"
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// persistent flag on parent
expected = "rootflag"
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// We better output info about our parent
expected = cmdRootWithRun.Short
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// And about subcommands
expected = cmdEchoSub.Short
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
unexpected := cmdDeprecated.Short
if strings.Contains(found, unexpected) {
t.Errorf("Unexpected response.\nFound: %v\nBut should not have!!\n", unexpected)
}
}
func TestGenYamlNoTag(t *testing.T) {
c := initializeWithRootCmd()
// Need two commands to run the command alphabetical sort
cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated)
c.AddCommand(cmdPrint, cmdEcho)
c.DisableAutoGenTag = true
cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp)
out := new(bytes.Buffer)
if err := GenYaml(c, out); err != nil {
t.Fatal(err)
}
found := out.String()
unexpected := "Auto generated"
checkStringOmits(t, found, unexpected)
}

View File

@ -1,4 +1,6 @@
[![Build Status](https://travis-ci.org/spf13/pflag.svg?branch=master)](https://travis-ci.org/spf13/pflag)
[![Go Report Card](https://goreportcard.com/badge/github.com/spf13/pflag)](https://goreportcard.com/report/github.com/spf13/pflag)
[![GoDoc](https://godoc.org/github.com/spf13/pflag?status.svg)](https://godoc.org/github.com/spf13/pflag)
## Description
@ -106,9 +108,9 @@ that give one-letter shorthands for flags. You can use these by appending
var ip = flag.IntP("flagname", "f", 1234, "help message")
var flagvar bool
func init() {
flag.BoolVarP("boolname", "b", true, "help message")
flag.BoolVarP(&flagvar, "boolname", "b", true, "help message")
}
flag.VarP(&flagVar, "varname", "v", 1234, "help message")
flag.VarP(&flagVal, "varname", "v", "help message")
```
Shorthand letters can be used with single dashes on the command line.
@ -268,8 +270,8 @@ func main() {
You can see the full reference documentation of the pflag package
[at godoc.org][3], or through go's standard documentation system by
running `godoc -http=:6060` and browsing to
[http://localhost:6060/pkg/github.com/ogier/pflag][2] after
[http://localhost:6060/pkg/github.com/spf13/pflag][2] after
installation.
[2]: http://localhost:6060/pkg/github.com/ogier/pflag
[3]: http://godoc.org/github.com/ogier/pflag
[2]: http://localhost:6060/pkg/github.com/spf13/pflag
[3]: http://godoc.org/github.com/spf13/pflag

View File

@ -0,0 +1,147 @@
package pflag
import (
"io"
"strconv"
"strings"
)
// -- boolSlice Value
type boolSliceValue struct {
value *[]bool
changed bool
}
func newBoolSliceValue(val []bool, p *[]bool) *boolSliceValue {
bsv := new(boolSliceValue)
bsv.value = p
*bsv.value = val
return bsv
}
// Set converts, and assigns, the comma-separated boolean argument string representation as the []bool value of this flag.
// If Set is called on a flag that already has a []bool assigned, the newly converted values will be appended.
func (s *boolSliceValue) Set(val string) error {
// remove all quote characters
rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "")
// read flag arguments with CSV parser
boolStrSlice, err := readAsCSV(rmQuote.Replace(val))
if err != nil && err != io.EOF {
return err
}
// parse boolean values into slice
out := make([]bool, 0, len(boolStrSlice))
for _, boolStr := range boolStrSlice {
b, err := strconv.ParseBool(strings.TrimSpace(boolStr))
if err != nil {
return err
}
out = append(out, b)
}
if !s.changed {
*s.value = out
} else {
*s.value = append(*s.value, out...)
}
s.changed = true
return nil
}
// Type returns a string that uniquely represents this flag's type.
func (s *boolSliceValue) Type() string {
return "boolSlice"
}
// String defines a "native" format for this boolean slice flag value.
func (s *boolSliceValue) String() string {
boolStrSlice := make([]string, len(*s.value))
for i, b := range *s.value {
boolStrSlice[i] = strconv.FormatBool(b)
}
out, _ := writeAsCSV(boolStrSlice)
return "[" + out + "]"
}
func boolSliceConv(val string) (interface{}, error) {
val = strings.Trim(val, "[]")
// Empty string would cause a slice with one (empty) entry
if len(val) == 0 {
return []bool{}, nil
}
ss := strings.Split(val, ",")
out := make([]bool, len(ss))
for i, t := range ss {
var err error
out[i], err = strconv.ParseBool(t)
if err != nil {
return nil, err
}
}
return out, nil
}
// GetBoolSlice returns the []bool value of a flag with the given name.
func (f *FlagSet) GetBoolSlice(name string) ([]bool, error) {
val, err := f.getFlagType(name, "boolSlice", boolSliceConv)
if err != nil {
return []bool{}, err
}
return val.([]bool), nil
}
// BoolSliceVar defines a boolSlice flag with specified name, default value, and usage string.
// The argument p points to a []bool variable in which to store the value of the flag.
func (f *FlagSet) BoolSliceVar(p *[]bool, name string, value []bool, usage string) {
f.VarP(newBoolSliceValue(value, p), name, "", usage)
}
// BoolSliceVarP is like BoolSliceVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BoolSliceVarP(p *[]bool, name, shorthand string, value []bool, usage string) {
f.VarP(newBoolSliceValue(value, p), name, shorthand, usage)
}
// BoolSliceVar defines a []bool flag with specified name, default value, and usage string.
// The argument p points to a []bool variable in which to store the value of the flag.
func BoolSliceVar(p *[]bool, name string, value []bool, usage string) {
CommandLine.VarP(newBoolSliceValue(value, p), name, "", usage)
}
// BoolSliceVarP is like BoolSliceVar, but accepts a shorthand letter that can be used after a single dash.
func BoolSliceVarP(p *[]bool, name, shorthand string, value []bool, usage string) {
CommandLine.VarP(newBoolSliceValue(value, p), name, shorthand, usage)
}
// BoolSlice defines a []bool flag with specified name, default value, and usage string.
// The return value is the address of a []bool variable that stores the value of the flag.
func (f *FlagSet) BoolSlice(name string, value []bool, usage string) *[]bool {
p := []bool{}
f.BoolSliceVarP(&p, name, "", value, usage)
return &p
}
// BoolSliceP is like BoolSlice, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BoolSliceP(name, shorthand string, value []bool, usage string) *[]bool {
p := []bool{}
f.BoolSliceVarP(&p, name, shorthand, value, usage)
return &p
}
// BoolSlice defines a []bool flag with specified name, default value, and usage string.
// The return value is the address of a []bool variable that stores the value of the flag.
func BoolSlice(name string, value []bool, usage string) *[]bool {
return CommandLine.BoolSliceP(name, "", value, usage)
}
// BoolSliceP is like BoolSlice, but accepts a shorthand letter that can be used after a single dash.
func BoolSliceP(name, shorthand string, value []bool, usage string) *[]bool {
return CommandLine.BoolSliceP(name, shorthand, value, usage)
}

View File

@ -0,0 +1,215 @@
package pflag
import (
"fmt"
"strconv"
"strings"
"testing"
)
func setUpBSFlagSet(bsp *[]bool) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.BoolSliceVar(bsp, "bs", []bool{}, "Command separated list!")
return f
}
func setUpBSFlagSetWithDefault(bsp *[]bool) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.BoolSliceVar(bsp, "bs", []bool{false, true}, "Command separated list!")
return f
}
func TestEmptyBS(t *testing.T) {
var bs []bool
f := setUpBSFlagSet(&bs)
err := f.Parse([]string{})
if err != nil {
t.Fatal("expected no error; got", err)
}
getBS, err := f.GetBoolSlice("bs")
if err != nil {
t.Fatal("got an error from GetBoolSlice():", err)
}
if len(getBS) != 0 {
t.Fatalf("got bs %v with len=%d but expected length=0", getBS, len(getBS))
}
}
func TestBS(t *testing.T) {
var bs []bool
f := setUpBSFlagSet(&bs)
vals := []string{"1", "F", "TRUE", "0"}
arg := fmt.Sprintf("--bs=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range bs {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if b != v {
t.Fatalf("expected is[%d] to be %s but got: %t", i, vals[i], v)
}
}
getBS, err := f.GetBoolSlice("bs")
if err != nil {
t.Fatalf("got error: %v", err)
}
for i, v := range getBS {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if b != v {
t.Fatalf("expected bs[%d] to be %s but got: %t from GetBoolSlice", i, vals[i], v)
}
}
}
func TestBSDefault(t *testing.T) {
var bs []bool
f := setUpBSFlagSetWithDefault(&bs)
vals := []string{"false", "T"}
err := f.Parse([]string{})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range bs {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if b != v {
t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v)
}
}
getBS, err := f.GetBoolSlice("bs")
if err != nil {
t.Fatal("got an error from GetBoolSlice():", err)
}
for i, v := range getBS {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatal("got an error from GetBoolSlice():", err)
}
if b != v {
t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v)
}
}
}
func TestBSWithDefault(t *testing.T) {
var bs []bool
f := setUpBSFlagSetWithDefault(&bs)
vals := []string{"FALSE", "1"}
arg := fmt.Sprintf("--bs=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range bs {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if b != v {
t.Fatalf("expected bs[%d] to be %t but got: %t", i, b, v)
}
}
getBS, err := f.GetBoolSlice("bs")
if err != nil {
t.Fatal("got an error from GetBoolSlice():", err)
}
for i, v := range getBS {
b, err := strconv.ParseBool(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if b != v {
t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v)
}
}
}
func TestBSCalledTwice(t *testing.T) {
var bs []bool
f := setUpBSFlagSet(&bs)
in := []string{"T,F", "T"}
expected := []bool{true, false, true}
argfmt := "--bs=%s"
arg1 := fmt.Sprintf(argfmt, in[0])
arg2 := fmt.Sprintf(argfmt, in[1])
err := f.Parse([]string{arg1, arg2})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range bs {
if expected[i] != v {
t.Fatalf("expected bs[%d] to be %t but got %t", i, expected[i], v)
}
}
}
func TestBSBadQuoting(t *testing.T) {
tests := []struct {
Want []bool
FlagArg []string
}{
{
Want: []bool{true, false, true},
FlagArg: []string{"1", "0", "true"},
},
{
Want: []bool{true, false},
FlagArg: []string{"True", "F"},
},
{
Want: []bool{true, false},
FlagArg: []string{"T", "0"},
},
{
Want: []bool{true, false},
FlagArg: []string{"1", "0"},
},
{
Want: []bool{true, false, false},
FlagArg: []string{"true,false", "false"},
},
{
Want: []bool{true, false, false, true, false, true, false},
FlagArg: []string{`"true,false,false,1,0, T"`, " false "},
},
{
Want: []bool{false, false, true, false, true, false, true},
FlagArg: []string{`"0, False, T,false , true,F"`, "true"},
},
}
for i, test := range tests {
var bs []bool
f := setUpBSFlagSet(&bs)
if err := f.Parse([]string{fmt.Sprintf("--bs=%s", strings.Join(test.FlagArg, ","))}); err != nil {
t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%#v",
err, test.FlagArg, test.Want[i])
}
for j, b := range bs {
if b != test.Want[j] {
t.Fatalf("bad value parsed for test %d on bool %d:\nwant:\t%t\ngot:\t%t", i, j, test.Want[j], b)
}
}
}
}

View File

@ -6,7 +6,6 @@ package pflag
import (
"bytes"
"fmt"
"strconv"
"testing"
)
@ -48,7 +47,7 @@ func (v *triStateValue) String() string {
if *v == triStateMaybe {
return strTriStateMaybe
}
return fmt.Sprintf("%v", bool(*v == triStateTrue))
return strconv.FormatBool(*v == triStateTrue)
}
// The type of the flag as required by the pflag.Value interface

View File

@ -1,13 +1,10 @@
package pflag
import (
"fmt"
"os"
"testing"
)
var _ = fmt.Printf
func setUpCount(c *int) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.CountVarP(c, "verbose", "v", "a counter")

View File

@ -16,9 +16,9 @@ pflag is a drop-in replacement of Go's native flag package. If you import
pflag under the name "flag" then all code should continue to function
with no changes.
import flag "github.com/ogier/pflag"
import flag "github.com/spf13/pflag"
There is one exception to this: if you directly instantiate the Flag struct
There is one exception to this: if you directly instantiate the Flag struct
there is one more field "Shorthand" that you will need to set.
Most code never instantiates this struct directly, and instead uses
functions such as String(), BoolVar(), and Var(), and is therefore
@ -134,14 +134,21 @@ type FlagSet struct {
// a custom error handler.
Usage func()
// SortFlags is used to indicate, if user wants to have sorted flags in
// help/usage messages.
SortFlags bool
name string
parsed bool
actual map[NormalizedName]*Flag
orderedActual []*Flag
sortedActual []*Flag
formal map[NormalizedName]*Flag
orderedFormal []*Flag
sortedFormal []*Flag
shorthands map[byte]*Flag
args []string // arguments after flags
argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor
interspersed bool // allow interspersed option/non-option args
@ -156,7 +163,7 @@ type Flag struct {
Value Value // value as set
DefValue string // default value (as text); for usage message
Changed bool // If the user set the value (or if left to default)
NoOptDefVal string //default value (as text); if the flag is on the command line without any options
NoOptDefVal string // default value (as text); if the flag is on the command line without any options
Deprecated string // If this flag is deprecated, this string is the new or now thing to use
Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text
ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use
@ -194,11 +201,13 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
// "--getUrl" which may also be translated to "geturl" and everything will work.
func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
f.normalizeNameFunc = n
for k, v := range f.formal {
delete(f.formal, k)
nname := f.normalizeFlagName(string(k))
f.formal[nname] = v
f.sortedFormal = f.sortedFormal[:0]
for k, v := range f.orderedFormal {
delete(f.formal, NormalizedName(v.Name))
nname := f.normalizeFlagName(v.Name)
v.Name = string(nname)
f.formal[nname] = v
f.orderedFormal[k] = v
}
}
@ -229,10 +238,25 @@ func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}
// VisitAll visits the flags in lexicographical order, calling fn for each.
// VisitAll visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) {
for _, flag := range sortFlags(f.formal) {
if len(f.formal) == 0 {
return
}
var flags []*Flag
if f.SortFlags {
if len(f.formal) != len(f.sortedFormal) {
f.sortedFormal = sortFlags(f.formal)
}
flags = f.sortedFormal
} else {
flags = f.orderedFormal
}
for _, flag := range flags {
fn(flag)
}
}
@ -253,22 +277,39 @@ func (f *FlagSet) HasAvailableFlags() bool {
return false
}
// VisitAll visits the command-line flags in lexicographical order, calling
// fn for each. It visits all flags, even those not set.
// VisitAll visits the command-line flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn)
}
// Visit visits the flags in lexicographical order, calling fn for each.
// Visit visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) {
for _, flag := range sortFlags(f.actual) {
if len(f.actual) == 0 {
return
}
var flags []*Flag
if f.SortFlags {
if len(f.actual) != len(f.sortedActual) {
f.sortedActual = sortFlags(f.actual)
}
flags = f.sortedActual
} else {
flags = f.orderedActual
}
for _, flag := range flags {
fn(flag)
}
}
// Visit visits the command-line flags in lexicographical order, calling fn
// for each. It visits only those flags that have been set.
// Visit visits the command-line flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func Visit(fn func(*Flag)) {
CommandLine.Visit(fn)
}
@ -373,6 +414,7 @@ func (f *FlagSet) Set(name, value string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[normalName] = flag
f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
@ -416,7 +458,7 @@ func Set(name, value string) error {
// otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() {
usages := f.FlagUsages()
fmt.Fprintf(f.out(), "%s", usages)
fmt.Fprint(f.out(), usages)
}
// defaultIsZeroValue returns true if the default value for this flag represents
@ -487,9 +529,76 @@ func UnquoteUsage(flag *Flag) (name string, usage string) {
return
}
// FlagUsages Returns a string containing the usage information for all flags in
// the FlagSet
func (f *FlagSet) FlagUsages() string {
// Splits the string `s` on whitespace into an initial substring up to
// `i` runes in length and the remainder. Will go `slop` over `i` if
// that encompasses the entire string (which allows the caller to
// avoid short orphan words on the final line).
func wrapN(i, slop int, s string) (string, string) {
if i+slop > len(s) {
return s, ""
}
w := strings.LastIndexAny(s[:i], " \t")
if w <= 0 {
return s, ""
}
return s[:w], s[w+1:]
}
// Wraps the string `s` to a maximum width `w` with leading indent
// `i`. The first line is not indented (this is assumed to be done by
// caller). Pass `w` == 0 to do no wrapping
func wrap(i, w int, s string) string {
if w == 0 {
return s
}
// space between indent i and end of line width w into which
// we should wrap the text.
wrap := w - i
var r, l string
// Not enough space for sensible wrapping. Wrap as a block on
// the next line instead.
if wrap < 24 {
i = 16
wrap = w - i
r += "\n" + strings.Repeat(" ", i)
}
// If still not enough space then don't even try to wrap.
if wrap < 24 {
return s
}
// Try to avoid short orphan words on the final line, by
// allowing wrapN to go a bit over if that would fit in the
// remainder of the line.
slop := 5
wrap = wrap - slop
// Handle first line, which is indented by the caller (or the
// special case above)
l, s = wrapN(wrap, slop, s)
r = r + l
// Now wrap the rest
for s != "" {
var t string
t, s = wrapN(wrap, slop, s)
r = r + "\n" + strings.Repeat(" ", i) + t
}
return r
}
// FlagUsagesWrapped returns a string containing the usage information
// for all flags in the FlagSet. Wrapped to `cols` columns (0 for no
// wrapping)
func (f *FlagSet) FlagUsagesWrapped(cols int) string {
x := new(bytes.Buffer)
lines := make([]string, 0, len(f.formal))
@ -514,7 +623,7 @@ func (f *FlagSet) FlagUsages() string {
if len(flag.NoOptDefVal) > 0 {
switch flag.Value.Type() {
case "string":
line += fmt.Sprintf("[=%q]", flag.NoOptDefVal)
line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal)
case "bool":
if flag.NoOptDefVal != "true" {
line += fmt.Sprintf("[=%s]", flag.NoOptDefVal)
@ -546,12 +655,19 @@ func (f *FlagSet) FlagUsages() string {
for _, line := range lines {
sidx := strings.Index(line, "\x00")
spacing := strings.Repeat(" ", maxlen-sidx)
fmt.Fprintln(x, line[:sidx], spacing, line[sidx+1:])
// maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx
fmt.Fprintln(x, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:]))
}
return x.String()
}
// FlagUsages returns a string containing the usage information for all flags in
// the FlagSet
func (f *FlagSet) FlagUsages() string {
return f.FlagUsagesWrapped(0)
}
// PrintDefaults prints to standard error the default values of all defined command-line flags.
func PrintDefaults() {
CommandLine.PrintDefaults()
@ -635,7 +751,7 @@ func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag {
// VarP is like Var, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) VarP(value Value, name, shorthand, usage string) {
_ = f.VarPF(value, name, shorthand, usage)
f.VarPF(value, name, shorthand, usage)
}
// AddFlag will add the flag to the FlagSet
@ -655,6 +771,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
flag.Name = string(normalizedFlagName)
f.formal[normalizedFlagName] = flag
f.orderedFormal = append(f.orderedFormal, flag)
if len(flag.Shorthand) == 0 {
return
@ -733,6 +850,7 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[f.normalizeFlagName(flag.Name)] = flag
f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
@ -752,7 +870,7 @@ func containsShorthand(arg, shorthand string) bool {
return strings.Contains(arg, shorthand)
}
func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) {
func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
name := s[2:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' {
@ -786,11 +904,11 @@ func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error)
err = f.failf("flag needs an argument: %s", s)
return
}
err = f.setFlag(flag, value, s)
err = fn(flag, value, s)
return
}
func (f *FlagSet) parseSingleShortArg(shorthands string, args []string) (outShorts string, outArgs []string, err error) {
func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) {
if strings.HasPrefix(shorthands, "test.") {
return
}
@ -825,16 +943,16 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string) (outShor
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return
}
err = f.setFlag(flag, value, shorthands)
err = fn(flag, value, shorthands)
return
}
func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) {
func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
shorthands := s[1:]
for len(shorthands) > 0 {
shorthands, a, err = f.parseSingleShortArg(shorthands, args)
shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn)
if err != nil {
return
}
@ -843,7 +961,7 @@ func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error)
return
}
func (f *FlagSet) parseArgs(args []string) (err error) {
func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) {
for len(args) > 0 {
s := args[0]
args = args[1:]
@ -863,9 +981,9 @@ func (f *FlagSet) parseArgs(args []string) (err error) {
f.args = append(f.args, args...)
break
}
args, err = f.parseLongArg(s, args)
args, err = f.parseLongArg(s, args, fn)
} else {
args, err = f.parseShortArg(s, args)
args, err = f.parseShortArg(s, args, fn)
}
if err != nil {
return
@ -881,7 +999,41 @@ func (f *FlagSet) parseArgs(args []string) (err error) {
func (f *FlagSet) Parse(arguments []string) error {
f.parsed = true
f.args = make([]string, 0, len(arguments))
err := f.parseArgs(arguments)
assign := func(flag *Flag, value, origArg string) error {
return f.setFlag(flag, value, origArg)
}
err := f.parseArgs(arguments, assign)
if err != nil {
switch f.errorHandling {
case ContinueOnError:
return err
case ExitOnError:
os.Exit(2)
case PanicOnError:
panic(err)
}
}
return nil
}
type parseFunc func(flag *Flag, value, origArg string) error
// ParseAll parses flag definitions from the argument list, which should not
// include the command name. The arguments for fn are flag and value. Must be
// called after all flags in the FlagSet are defined and before flags are
// accessed by the program. The return value will be ErrHelp if -help was set
// but not defined.
func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error {
f.parsed = true
f.args = make([]string, 0, len(arguments))
assign := func(flag *Flag, value, origArg string) error {
return fn(flag, value)
}
err := f.parseArgs(arguments, assign)
if err != nil {
switch f.errorHandling {
case ContinueOnError:
@ -907,6 +1059,14 @@ func Parse() {
CommandLine.Parse(os.Args[1:])
}
// ParseAll parses the command-line flags from os.Args[1:] and called fn for each.
// The arguments for fn are flag and value. Must be called after all flags are
// defined and before flags are accessed by the program.
func ParseAll(fn func(flag *Flag, value string) error) {
// Ignore errors; CommandLine is set for ExitOnError.
CommandLine.ParseAll(os.Args[1:], fn)
}
// SetInterspersed sets whether to support interspersed option/non-option arguments.
func SetInterspersed(interspersed bool) {
CommandLine.SetInterspersed(interspersed)
@ -920,14 +1080,15 @@ func Parsed() bool {
// CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)
// NewFlagSet returns a new, empty flag set with the specified name and
// error handling property.
// NewFlagSet returns a new, empty flag set with the specified name,
// error handling property and SortFlags set to true.
func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet {
f := &FlagSet{
name: name,
errorHandling: errorHandling,
argsLenAtDash: -1,
interspersed: true,
SortFlags: true,
}
return f
}

View File

@ -333,6 +333,59 @@ func testParse(f *FlagSet, t *testing.T) {
}
}
func testParseAll(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
f.BoolP("boola", "a", false, "bool value")
f.BoolP("boolb", "b", false, "bool2 value")
f.BoolP("boolc", "c", false, "bool3 value")
f.BoolP("boold", "d", false, "bool4 value")
f.StringP("stringa", "s", "0", "string value")
f.StringP("stringz", "z", "0", "string value")
f.StringP("stringx", "x", "0", "string value")
f.StringP("stringy", "y", "0", "string value")
f.Lookup("stringx").NoOptDefVal = "1"
args := []string{
"-ab",
"-cs=xx",
"--stringz=something",
"-d=true",
"-x",
"-y",
"ee",
}
want := []string{
"boola", "true",
"boolb", "true",
"boolc", "true",
"stringa", "xx",
"stringz", "something",
"boold", "true",
"stringx", "1",
"stringy", "ee",
}
got := []string{}
store := func(flag *Flag, value string) error {
got = append(got, flag.Name)
if len(value) > 0 {
got = append(got, value)
}
return nil
}
if err := f.ParseAll(args, store); err != nil {
t.Errorf("expected no error, got %s", err)
}
if !f.Parsed() {
t.Errorf("f.Parse() = false after Parse")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("Got: %v", got)
t.Errorf("Want: %v", want)
}
}
func TestShorthand(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if f.Parsed() {
@ -398,16 +451,21 @@ func TestParse(t *testing.T) {
testParse(GetCommandLine(), t)
}
func TestParseAll(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testParseAll(GetCommandLine(), t)
}
func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t)
}
func TestChangedHelper(t *testing.T) {
f := NewFlagSet("changedtest", ContinueOnError)
_ = f.Bool("changed", false, "changed bool")
_ = f.Bool("settrue", true, "true to true")
_ = f.Bool("setfalse", false, "false to false")
_ = f.Bool("unchanged", false, "unchanged bool")
f.Bool("changed", false, "changed bool")
f.Bool("settrue", true, "true to true")
f.Bool("setfalse", false, "false to false")
f.Bool("unchanged", false, "unchanged bool")
args := []string{"--changed", "--settrue", "--setfalse=false"}
if err := f.Parse(args); err != nil {
@ -946,3 +1004,43 @@ func TestPrintDefaults(t *testing.T) {
t.Errorf("got %q want %q\n", got, defaultOutput)
}
}
func TestVisitAllFlagOrder(t *testing.T) {
fs := NewFlagSet("TestVisitAllFlagOrder", ContinueOnError)
fs.SortFlags = false
// https://github.com/spf13/pflag/issues/120
fs.SetNormalizeFunc(func(f *FlagSet, name string) NormalizedName {
return NormalizedName(name)
})
names := []string{"C", "B", "A", "D"}
for _, name := range names {
fs.Bool(name, false, "")
}
i := 0
fs.VisitAll(func(f *Flag) {
if names[i] != f.Name {
t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
}
i++
})
}
func TestVisitFlagOrder(t *testing.T) {
fs := NewFlagSet("TestVisitFlagOrder", ContinueOnError)
fs.SortFlags = false
names := []string{"C", "B", "A", "D"}
for _, name := range names {
fs.Bool(name, false, "")
fs.Set(name, "true")
}
i := 0
fs.Visit(func(f *Flag) {
if names[i] != f.Name {
t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
}
i++
})
}

View File

@ -6,13 +6,10 @@ package pflag
import (
goflag "flag"
"fmt"
"reflect"
"strings"
)
var _ = fmt.Print
// flagValueWrapper implements pflag.Value around a flag.Value. The main
// difference here is the addition of the Type method that returns a string
// name of the type. As this is generally unknown, we approximate that with

View File

@ -6,8 +6,6 @@ import (
"strings"
)
var _ = strings.TrimSpace
// -- net.IP value
type ipValue net.IP

View File

@ -0,0 +1,148 @@
package pflag
import (
"fmt"
"io"
"net"
"strings"
)
// -- ipSlice Value
type ipSliceValue struct {
value *[]net.IP
changed bool
}
func newIPSliceValue(val []net.IP, p *[]net.IP) *ipSliceValue {
ipsv := new(ipSliceValue)
ipsv.value = p
*ipsv.value = val
return ipsv
}
// Set converts, and assigns, the comma-separated IP argument string representation as the []net.IP value of this flag.
// If Set is called on a flag that already has a []net.IP assigned, the newly converted values will be appended.
func (s *ipSliceValue) Set(val string) error {
// remove all quote characters
rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "")
// read flag arguments with CSV parser
ipStrSlice, err := readAsCSV(rmQuote.Replace(val))
if err != nil && err != io.EOF {
return err
}
// parse ip values into slice
out := make([]net.IP, 0, len(ipStrSlice))
for _, ipStr := range ipStrSlice {
ip := net.ParseIP(strings.TrimSpace(ipStr))
if ip == nil {
return fmt.Errorf("invalid string being converted to IP address: %s", ipStr)
}
out = append(out, ip)
}
if !s.changed {
*s.value = out
} else {
*s.value = append(*s.value, out...)
}
s.changed = true
return nil
}
// Type returns a string that uniquely represents this flag's type.
func (s *ipSliceValue) Type() string {
return "ipSlice"
}
// String defines a "native" format for this net.IP slice flag value.
func (s *ipSliceValue) String() string {
ipStrSlice := make([]string, len(*s.value))
for i, ip := range *s.value {
ipStrSlice[i] = ip.String()
}
out, _ := writeAsCSV(ipStrSlice)
return "[" + out + "]"
}
func ipSliceConv(val string) (interface{}, error) {
val = strings.Trim(val, "[]")
// Emtpy string would cause a slice with one (empty) entry
if len(val) == 0 {
return []net.IP{}, nil
}
ss := strings.Split(val, ",")
out := make([]net.IP, len(ss))
for i, sval := range ss {
ip := net.ParseIP(strings.TrimSpace(sval))
if ip == nil {
return nil, fmt.Errorf("invalid string being converted to IP address: %s", sval)
}
out[i] = ip
}
return out, nil
}
// GetIPSlice returns the []net.IP value of a flag with the given name
func (f *FlagSet) GetIPSlice(name string) ([]net.IP, error) {
val, err := f.getFlagType(name, "ipSlice", ipSliceConv)
if err != nil {
return []net.IP{}, err
}
return val.([]net.IP), nil
}
// IPSliceVar defines a ipSlice flag with specified name, default value, and usage string.
// The argument p points to a []net.IP variable in which to store the value of the flag.
func (f *FlagSet) IPSliceVar(p *[]net.IP, name string, value []net.IP, usage string) {
f.VarP(newIPSliceValue(value, p), name, "", usage)
}
// IPSliceVarP is like IPSliceVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) IPSliceVarP(p *[]net.IP, name, shorthand string, value []net.IP, usage string) {
f.VarP(newIPSliceValue(value, p), name, shorthand, usage)
}
// IPSliceVar defines a []net.IP flag with specified name, default value, and usage string.
// The argument p points to a []net.IP variable in which to store the value of the flag.
func IPSliceVar(p *[]net.IP, name string, value []net.IP, usage string) {
CommandLine.VarP(newIPSliceValue(value, p), name, "", usage)
}
// IPSliceVarP is like IPSliceVar, but accepts a shorthand letter that can be used after a single dash.
func IPSliceVarP(p *[]net.IP, name, shorthand string, value []net.IP, usage string) {
CommandLine.VarP(newIPSliceValue(value, p), name, shorthand, usage)
}
// IPSlice defines a []net.IP flag with specified name, default value, and usage string.
// The return value is the address of a []net.IP variable that stores the value of that flag.
func (f *FlagSet) IPSlice(name string, value []net.IP, usage string) *[]net.IP {
p := []net.IP{}
f.IPSliceVarP(&p, name, "", value, usage)
return &p
}
// IPSliceP is like IPSlice, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) IPSliceP(name, shorthand string, value []net.IP, usage string) *[]net.IP {
p := []net.IP{}
f.IPSliceVarP(&p, name, shorthand, value, usage)
return &p
}
// IPSlice defines a []net.IP flag with specified name, default value, and usage string.
// The return value is the address of a []net.IP variable that stores the value of the flag.
func IPSlice(name string, value []net.IP, usage string) *[]net.IP {
return CommandLine.IPSliceP(name, "", value, usage)
}
// IPSliceP is like IPSlice, but accepts a shorthand letter that can be used after a single dash.
func IPSliceP(name, shorthand string, value []net.IP, usage string) *[]net.IP {
return CommandLine.IPSliceP(name, shorthand, value, usage)
}

View File

@ -0,0 +1,222 @@
package pflag
import (
"fmt"
"net"
"strings"
"testing"
)
func setUpIPSFlagSet(ipsp *[]net.IP) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.IPSliceVar(ipsp, "ips", []net.IP{}, "Command separated list!")
return f
}
func setUpIPSFlagSetWithDefault(ipsp *[]net.IP) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.IPSliceVar(ipsp, "ips",
[]net.IP{
net.ParseIP("192.168.1.1"),
net.ParseIP("0:0:0:0:0:0:0:1"),
},
"Command separated list!")
return f
}
func TestEmptyIP(t *testing.T) {
var ips []net.IP
f := setUpIPSFlagSet(&ips)
err := f.Parse([]string{})
if err != nil {
t.Fatal("expected no error; got", err)
}
getIPS, err := f.GetIPSlice("ips")
if err != nil {
t.Fatal("got an error from GetIPSlice():", err)
}
if len(getIPS) != 0 {
t.Fatalf("got ips %v with len=%d but expected length=0", getIPS, len(getIPS))
}
}
func TestIPS(t *testing.T) {
var ips []net.IP
f := setUpIPSFlagSet(&ips)
vals := []string{"192.168.1.1", "10.0.0.1", "0:0:0:0:0:0:0:2"}
arg := fmt.Sprintf("--ips=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range ips {
if ip := net.ParseIP(vals[i]); ip == nil {
t.Fatalf("invalid string being converted to IP address: %s", vals[i])
} else if !ip.Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v)
}
}
}
func TestIPSDefault(t *testing.T) {
var ips []net.IP
f := setUpIPSFlagSetWithDefault(&ips)
vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"}
err := f.Parse([]string{})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range ips {
if ip := net.ParseIP(vals[i]); ip == nil {
t.Fatalf("invalid string being converted to IP address: %s", vals[i])
} else if !ip.Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v)
}
}
getIPS, err := f.GetIPSlice("ips")
if err != nil {
t.Fatal("got an error from GetIPSlice")
}
for i, v := range getIPS {
if ip := net.ParseIP(vals[i]); ip == nil {
t.Fatalf("invalid string being converted to IP address: %s", vals[i])
} else if !ip.Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v)
}
}
}
func TestIPSWithDefault(t *testing.T) {
var ips []net.IP
f := setUpIPSFlagSetWithDefault(&ips)
vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"}
arg := fmt.Sprintf("--ips=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range ips {
if ip := net.ParseIP(vals[i]); ip == nil {
t.Fatalf("invalid string being converted to IP address: %s", vals[i])
} else if !ip.Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v)
}
}
getIPS, err := f.GetIPSlice("ips")
if err != nil {
t.Fatal("got an error from GetIPSlice")
}
for i, v := range getIPS {
if ip := net.ParseIP(vals[i]); ip == nil {
t.Fatalf("invalid string being converted to IP address: %s", vals[i])
} else if !ip.Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v)
}
}
}
func TestIPSCalledTwice(t *testing.T) {
var ips []net.IP
f := setUpIPSFlagSet(&ips)
in := []string{"192.168.1.2,0:0:0:0:0:0:0:1", "10.0.0.1"}
expected := []net.IP{net.ParseIP("192.168.1.2"), net.ParseIP("0:0:0:0:0:0:0:1"), net.ParseIP("10.0.0.1")}
argfmt := "ips=%s"
arg1 := fmt.Sprintf(argfmt, in[0])
arg2 := fmt.Sprintf(argfmt, in[1])
err := f.Parse([]string{arg1, arg2})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range ips {
if !expected[i].Equal(v) {
t.Fatalf("expected ips[%d] to be %s but got: %s", i, expected[i], v)
}
}
}
func TestIPSBadQuoting(t *testing.T) {
tests := []struct {
Want []net.IP
FlagArg []string
}{
{
Want: []net.IP{
net.ParseIP("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568"),
net.ParseIP("203.107.49.208"),
net.ParseIP("14.57.204.90"),
},
FlagArg: []string{
"a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568",
"203.107.49.208",
"14.57.204.90",
},
},
{
Want: []net.IP{
net.ParseIP("204.228.73.195"),
net.ParseIP("86.141.15.94"),
},
FlagArg: []string{
"204.228.73.195",
"86.141.15.94",
},
},
{
Want: []net.IP{
net.ParseIP("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f"),
net.ParseIP("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472"),
},
FlagArg: []string{
"c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f",
"4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472",
},
},
{
Want: []net.IP{
net.ParseIP("5170:f971:cfac:7be3:512a:af37:952c:bc33"),
net.ParseIP("93.21.145.140"),
net.ParseIP("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca"),
},
FlagArg: []string{
" 5170:f971:cfac:7be3:512a:af37:952c:bc33 , 93.21.145.140 ",
"2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca",
},
},
{
Want: []net.IP{
net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"),
net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"),
net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"),
net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"),
},
FlagArg: []string{
`"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b "`,
" 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"},
},
}
for i, test := range tests {
var ips []net.IP
f := setUpIPSFlagSet(&ips)
if err := f.Parse([]string{fmt.Sprintf("--ips=%s", strings.Join(test.FlagArg, ","))}); err != nil {
t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s",
err, test.FlagArg, test.Want[i])
}
for j, b := range ips {
if !b.Equal(test.Want[j]) {
t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b)
}
}
}
}

View File

@ -27,8 +27,6 @@ func (*ipNetValue) Type() string {
return "ipNet"
}
var _ = strings.TrimSpace
func newIPNetValue(val net.IPNet, p *net.IPNet) *ipNetValue {
*p = val
return (*ipNetValue)(p)

View File

@ -1,12 +1,5 @@
package pflag
import (
"fmt"
"strings"
)
var _ = fmt.Fprint
// -- stringArray Value
type stringArrayValue struct {
value *[]string
@ -40,7 +33,7 @@ func (s *stringArrayValue) String() string {
}
func stringArrayConv(sval string) (interface{}, error) {
sval = strings.Trim(sval, "[]")
sval = sval[1 : len(sval)-1]
// An empty string would cause a array with one (empty) string
if len(sval) == 0 {
return []string{}, nil

Some files were not shown because too many files have changed in this diff Show More