diff --git a/cmd/restic/cmd_backup.go b/cmd/restic/cmd_backup.go index 3f9fece97..da924d239 100644 --- a/cmd/restic/cmd_backup.go +++ b/cmd/restic/cmd_backup.go @@ -556,8 +556,8 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, term *termstatus.Termina var targetFS fs.FS = fs.Local{} if runtime.GOOS == "windows" && opts.UseFsSnapshot { - if !fs.HasSufficientPrivilegesForVSS() { - return errors.Fatal("user doesn't have sufficient privileges to use VSS snapshots\n") + if err = fs.HasSufficientPrivilegesForVSS(); err != nil { + return err } errorHandler := func(item string, err error) error { diff --git a/cmd/restic/integration_test.go b/cmd/restic/integration_test.go index 9c01939ec..66a129598 100644 --- a/cmd/restic/integration_test.go +++ b/cmd/restic/integration_test.go @@ -286,7 +286,7 @@ func TestBackup(t *testing.T) { } func TestBackupWithFilesystemSnapshots(t *testing.T) { - if runtime.GOOS == "windows" && fs.HasSufficientPrivilegesForVSS() { + if runtime.GOOS == "windows" && fs.HasSufficientPrivilegesForVSS() == nil { testBackup(t, true) } } diff --git a/internal/fs/vss.go b/internal/fs/vss.go index a515d75b2..ca0604906 100644 --- a/internal/fs/vss.go +++ b/internal/fs/vss.go @@ -26,8 +26,8 @@ type VssSnapshot struct { } // HasSufficientPrivilegesForVSS returns true if the user is allowed to use VSS. -func HasSufficientPrivilegesForVSS() bool { - return false +func HasSufficientPrivilegesForVSS() error { + return errors.New("VSS snapshots are only supported on windows") } // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't diff --git a/internal/fs/vss_windows.go b/internal/fs/vss_windows.go index b63ad4cd8..e24dade80 100644 --- a/internal/fs/vss_windows.go +++ b/internal/fs/vss_windows.go @@ -686,10 +686,10 @@ func (p *VssSnapshot) GetSnapshotDeviceObject() string { } // initializeCOMInterface initialize an instance of the VSS COM api -func initializeVssCOMInterface() (*ole.IUnknown, uintptr, error) { +func initializeVssCOMInterface() (*ole.IUnknown, error) { vssInstance, err := loadIVssBackupComponentsConstructor() if err != nil { - return nil, 0, err + return nil, err } // ensure COM is initialized before use @@ -697,22 +697,33 @@ func initializeVssCOMInterface() (*ole.IUnknown, uintptr, error) { var oleIUnknown *ole.IUnknown result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown))) + hresult := HRESULT(result) - return oleIUnknown, result, nil + switch hresult { + case S_OK: + case E_ACCESSDENIED: + return oleIUnknown, newVssError( + "The caller does not have sufficient backup privileges or is not an administrator", + hresult) + default: + return oleIUnknown, newVssError("Failed to create VSS instance", hresult) + } + + if oleIUnknown == nil { + return nil, newVssError("Failed to initialize COM interface", hresult) + } + + return oleIUnknown, nil } -// HasSufficientPrivilegesForVSS returns true if the user is allowed to use VSS. -func HasSufficientPrivilegesForVSS() bool { - oleIUnknown, result, err := initializeVssCOMInterface() +// HasSufficientPrivilegesForVSS returns nil if the user is allowed to use VSS. +func HasSufficientPrivilegesForVSS() error { + oleIUnknown, err := initializeVssCOMInterface() if oleIUnknown != nil { oleIUnknown.Release() } - if err != nil { - return false - } - - return !(HRESULT(result) == E_ACCESSDENIED) + return err } // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't @@ -734,24 +745,12 @@ func NewVssSnapshot( timeoutInMillis := uint32(timeoutInSeconds * 1000) - oleIUnknown, result, err := initializeVssCOMInterface() - if err != nil { - if oleIUnknown != nil { - oleIUnknown.Release() - } - return VssSnapshot{}, err + oleIUnknown, err := initializeVssCOMInterface() + if oleIUnknown != nil { + defer oleIUnknown.Release() } - defer oleIUnknown.Release() - - switch HRESULT(result) { - case S_OK: - case E_ACCESSDENIED: - return VssSnapshot{}, newVssTextError(fmt.Sprintf("%s (%#x) The caller does not have "+ - "sufficient backup privileges or is not an administrator.", HRESULT(result).Str(), - result)) - default: - return VssSnapshot{}, newVssTextError(fmt.Sprintf("Failed to create VSS instance: %s (%#x)", - HRESULT(result).Str(), result)) + if err != nil { + return VssSnapshot{}, err } comInterface, err := queryInterface(oleIUnknown, UUID_IVSS)