lib/osutil: Replace IsDir with TraversesSymlink (fixes #3839)

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3883
LGTM: calmh
This commit is contained in:
Audrius Butkevicius 2017-01-10 07:09:31 +00:00 committed by Jakob Borg
parent 8d2a31e38e
commit 1a1e35d998
5 changed files with 98 additions and 70 deletions

View File

@ -121,7 +121,6 @@ var (
errDevicePaused = errors.New("device is paused")
errDeviceIgnored = errors.New("device is ignored")
errNotRelative = errors.New("not a relative path")
errNotDir = errors.New("parent is not a directory")
)
// NewModel creates and starts a new model. The model starts in read-only mode,
@ -1159,8 +1158,8 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset
return protocol.ErrNoSuchFile
}
if !osutil.IsDir(folderPath, filepath.Dir(name)) {
l.Debugf("%v REQ(in) for file not in dir: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, len(buf))
if err := osutil.TraversesSymlink(folderPath, filepath.Dir(name)); err != nil {
l.Debugf("%v REQ(in) traversal check: %s - %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, len(buf))
return protocol.ErrNoSuchFile
}

View File

@ -431,8 +431,8 @@ func (f *sendReceiveFolder) pullerIteration(ignores *ignore.Matcher) int {
for _, fi := range processDirectly {
// Verify that the thing we are handling lives inside a directory,
// and not a symlink or empty space.
if !osutil.IsDir(f.dir, filepath.Dir(fi.Name)) {
f.newError(fi.Name, errNotDir)
if err := osutil.TraversesSymlink(f.dir, filepath.Dir(fi.Name)); err != nil {
f.newError(fi.Name, err)
continue
}
@ -520,8 +520,8 @@ nextFile:
// Verify that the thing we are handling lives inside a directory,
// and not a symlink or empty space.
if !osutil.IsDir(f.dir, filepath.Dir(fi.Name)) {
f.newError(fi.Name, errNotDir)
if err := osutil.TraversesSymlink(f.dir, filepath.Dir(fi.Name)); err != nil {
f.newError(fi.Name, err)
continue
}

View File

@ -1,49 +0,0 @@
// Copyright (C) 2016 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package osutil
import (
"os"
"path/filepath"
"strings"
)
// IsDir returns true if base and every path component of name up to and
// including filepath.Join(base, name) is a directory (and not a symlink or
// similar). Base and name must both be clean and name must be relative to
// base.
func IsDir(base, name string) bool {
path := base
info, err := Lstat(path)
if err != nil {
return false
}
if !info.IsDir() {
return false
}
if name == "." {
// The result of calling IsDir("some/where", filepath.Dir("foo"))
return true
}
parts := strings.Split(name, string(os.PathSeparator))
for _, part := range parts {
path = filepath.Join(path, part)
info, err := Lstat(path)
if err != nil {
return false
}
if info.Mode()&os.ModeSymlink != 0 {
return false
}
if !info.IsDir() {
return false
}
}
return true
}

View File

@ -0,0 +1,76 @@
// Copyright (C) 2016 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package osutil
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// TraversesSymlinkError is an error indicating symlink traversal
type TraversesSymlinkError struct {
path string
}
func (e TraversesSymlinkError) Error() string {
return fmt.Sprintf("traverses symlink: %s", e.path)
}
// NotADirectoryError is an error indicating an expected path is not a directory
type NotADirectoryError struct {
path string
}
func (e NotADirectoryError) Error() string {
return fmt.Sprintf("not a directory: %s", e.path)
}
// TraversesSymlink returns an error if base and any path component of name up to and
// including filepath.Join(base, name) traverses a symlink.
// Base and name must both be clean and name must be relative to base.
func TraversesSymlink(base, name string) error {
path := base
info, err := Lstat(path)
if err != nil {
return err
}
if !info.IsDir() {
return &NotADirectoryError{
path: base,
}
}
if name == "." {
// The result of calling TraversesSymlink("some/where", filepath.Dir("foo"))
return nil
}
parts := strings.Split(name, string(os.PathSeparator))
for _, part := range parts {
path = filepath.Join(path, part)
info, err := Lstat(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if info.Mode()&os.ModeSymlink != 0 {
return &TraversesSymlinkError{
path: strings.TrimPrefix(path, base),
}
}
if !info.IsDir() {
return &NotADirectoryError{
path: strings.TrimPrefix(path, base),
}
}
}
return nil
}

View File

@ -14,7 +14,7 @@ import (
"github.com/syncthing/syncthing/lib/symlinks"
)
func TestIsDir(t *testing.T) {
func TestTraversesSymlink(t *testing.T) {
if !symlinks.Supported {
t.Skip("pointless test")
return
@ -35,40 +35,42 @@ func TestIsDir(t *testing.T) {
}
cases := []struct {
name string
isDir bool
name string
traverses bool
}{
// Exist
{".", true},
{"a", true},
{"a/b", true},
{"a/b/c", true},
{".", false},
{"a", false},
{"a/b", false},
{"a/b/c", false},
// Don't exist
{"x", false},
{"a/x", false},
{"a/b/x", false},
{"a/x/c", false},
// Symlink or behind symlink
{"a/l", false},
{"a/l/c", false},
{"a/l", true},
{"a/l/c", true},
// Non-existing behind a symlink
{"a/l/x", true},
}
for _, tc := range cases {
if res := osutil.IsDir("testdata", tc.name); res != tc.isDir {
t.Errorf("IsDir(%q) = %v, should be %v", tc.name, res, tc.isDir)
if res := osutil.TraversesSymlink("testdata", tc.name); tc.traverses == (res == nil) {
t.Errorf("TraversesSymlink(%q) = %v, should be %v", tc.name, res, tc.traverses)
}
}
}
var isDirResult bool
var traversesSymlinkResult error
func BenchmarkIsDir(b *testing.B) {
func BenchmarkTraversesSymlink(b *testing.B) {
os.RemoveAll("testdata")
defer os.RemoveAll("testdata")
os.MkdirAll("testdata/a/b/c", 0755)
for i := 0; i < b.N; i++ {
isDirResult = osutil.IsDir("testdata", "a/b/c")
traversesSymlinkResult = osutil.TraversesSymlink("testdata", "a/b/c")
}
b.ReportAllocs()