From eca076cf7d4294f4d294c921af39066739471550 Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Tue, 17 Apr 2018 22:53:06 +0200 Subject: [PATCH] lib/osutil: Fix TraversesSymlink with symlinked fs root on windows (fixes #4875) (#4886) --- lib/osutil/traversessymlink.go | 25 +++++--------- lib/osutil/traversessymlink_test.go | 52 +++++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/lib/osutil/traversessymlink.go b/lib/osutil/traversessymlink.go index 52d199c37..0ac5c8f5f 100644 --- a/lib/osutil/traversessymlink.go +++ b/lib/osutil/traversessymlink.go @@ -32,29 +32,22 @@ func (e NotADirectoryError) Error() string { return fmt.Sprintf("not a directory: %s", e.path) } -// TraversesSymlink returns an error if base and any path component of name up to and -// including filepath.Join(base, name) traverses a symlink. -// Base and name must both be clean and name must be relative to base. +// TraversesSymlink returns an error if any path component of name (including name +// itself) traverses a symlink. func TraversesSymlink(filesystem fs.Filesystem, name string) error { - base := "." - path := base - info, err := filesystem.Lstat(path) + var err error + name, err = fs.Canonicalize(name) if err != nil { return err } - if !info.IsDir() { - return &NotADirectoryError{ - path: base, - } - } if name == "." { - // The result of calling TraversesSymlink("some/where", filepath.Dir("foo")) + // The result of calling TraversesSymlink(filesystem, filepath.Dir("foo")) return nil } - parts := strings.Split(name, string(fs.PathSeparator)) - for _, part := range parts { + var path string + for _, part := range strings.Split(name, string(fs.PathSeparator)) { path = filepath.Join(path, part) info, err := filesystem.Lstat(path) if err != nil { @@ -65,12 +58,12 @@ func TraversesSymlink(filesystem fs.Filesystem, name string) error { } if info.IsSymlink() { return &TraversesSymlinkError{ - path: strings.TrimPrefix(path, base), + path: path, } } if !info.IsDir() { return &NotADirectoryError{ - path: strings.TrimPrefix(path, base), + path: path, } } } diff --git a/lib/osutil/traversessymlink_test.go b/lib/osutil/traversessymlink_test.go index c1ec67a8a..b69affaac 100644 --- a/lib/osutil/traversessymlink_test.go +++ b/lib/osutil/traversessymlink_test.go @@ -4,12 +4,13 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -// +build !windows - package osutil_test import ( + "io/ioutil" "os" + "path/filepath" + "runtime" "testing" "github.com/syncthing/syncthing/lib/fs" @@ -17,12 +18,20 @@ import ( ) func TestTraversesSymlink(t *testing.T) { - os.RemoveAll("testdata") - defer os.RemoveAll("testdata") + tmpDir, err := ioutil.TempDir(".", ".test-TraversesSymlink-") + if err != nil { + panic("Failed to create temporary testing dir") + } + defer os.RemoveAll(tmpDir) - fs := fs.NewFilesystem(fs.FilesystemTypeBasic, "testdata") + fs := fs.NewFilesystem(fs.FilesystemTypeBasic, tmpDir) fs.MkdirAll("a/b/c", 0755) - fs.CreateSymlink("b", "a/l") + if err = osutil.DebugSymlinkForTestsOnly(filepath.Join(fs.URI(), "a", "b"), filepath.Join(fs.URI(), "a", "l")); err != nil { + if runtime.GOOS == "windows" { + t.Skip("Symlinks aren't working") + } + t.Fatal(err) + } // a/l -> b, so a/l/c should resolve by normal stat info, err := fs.Lstat("a/l/c") @@ -61,6 +70,37 @@ func TestTraversesSymlink(t *testing.T) { } } +func TestIssue4875(t *testing.T) { + tmpDir, err := ioutil.TempDir("", ".test-Issue4875-") + if err != nil { + panic("Failed to create temporary testing dir") + } + defer os.RemoveAll(tmpDir) + + testFs := fs.NewFilesystem(fs.FilesystemTypeBasic, tmpDir) + testFs.MkdirAll("a/b/c", 0755) + if err = osutil.DebugSymlinkForTestsOnly(filepath.Join(testFs.URI(), "a", "b"), filepath.Join(testFs.URI(), "a", "l")); err != nil { + if runtime.GOOS == "windows" { + t.Skip("Symlinks aren't working") + } + t.Fatal(err) + } + + // a/l -> b, so a/l/c should resolve by normal stat + info, err := testFs.Lstat("a/l/c") + if err != nil { + t.Fatal("unexpected error", err) + } + if !info.IsDir() { + t.Fatal("error in setup, a/l/c should be a directory") + } + + testFs = fs.NewFilesystem(fs.FilesystemTypeBasic, filepath.Join(tmpDir, "a/l")) + if err := osutil.TraversesSymlink(testFs, "."); err != nil { + t.Error(`TraversesSymlink on filesystem with symlink at root returned error for ".":`, err) + } +} + var traversesSymlinkResult error func BenchmarkTraversesSymlink(b *testing.B) {