diff --git a/cmd/restic/lock.go b/cmd/restic/lock.go
index 0cea02cfd..188bb1d59 100644
--- a/cmd/restic/lock.go
+++ b/cmd/restic/lock.go
@@ -11,10 +11,13 @@ import (
 	"github.com/restic/restic/internal/restic"
 )
 
+type lockContext struct {
+	cancel    context.CancelFunc
+	refreshWG sync.WaitGroup
+}
+
 var globalLocks struct {
-	locks         []*restic.Lock
-	cancelRefresh chan struct{}
-	refreshWG     sync.WaitGroup
+	locks map[*restic.Lock]*lockContext
 	sync.Mutex
 	sync.Once
 }
@@ -27,6 +30,8 @@ func lockRepoExclusive(ctx context.Context, repo *repository.Repository) (*resti
 	return lockRepository(ctx, repo, true)
 }
 
+// lockRepository wraps the ctx such that it is cancelled when the repository is unlocked.
+// Cancelling the original context also stops the lock refresh.
 func lockRepository(ctx context.Context, repo *repository.Repository, exclusive bool) (*restic.Lock, context.Context, error) {
 	// make sure that a repository is unlocked properly and after cancel() was
 	// called by the cleanup handler in global.go
@@ -45,16 +50,17 @@ func lockRepository(ctx context.Context, repo *repository.Repository, exclusive
 	}
 	debug.Log("create lock %p (exclusive %v)", lock, exclusive)
 
-	globalLocks.Lock()
-	if globalLocks.cancelRefresh == nil {
-		debug.Log("start goroutine for lock refresh")
-		globalLocks.cancelRefresh = make(chan struct{})
-		globalLocks.refreshWG = sync.WaitGroup{}
-		globalLocks.refreshWG.Add(1)
-		go refreshLocks(&globalLocks.refreshWG, globalLocks.cancelRefresh)
+	ctx, cancel := context.WithCancel(ctx)
+	lockInfo := &lockContext{
+		cancel: cancel,
 	}
+	lockInfo.refreshWG.Add(2)
+	refreshChan := make(chan struct{})
 
-	globalLocks.locks = append(globalLocks.locks, lock)
+	globalLocks.Lock()
+	globalLocks.locks[lock] = lockInfo
+	go refreshLocks(ctx, lock, lockInfo, refreshChan)
+	go monitorLockRefresh(ctx, lock, lockInfo, refreshChan)
 	globalLocks.Unlock()
 
 	return lock, ctx, err
@@ -62,32 +68,76 @@ func lockRepository(ctx context.Context, repo *repository.Repository, exclusive
 
 var refreshInterval = 5 * time.Minute
 
-func refreshLocks(wg *sync.WaitGroup, done <-chan struct{}) {
-	debug.Log("start")
-	defer func() {
-		wg.Done()
-		globalLocks.Lock()
-		globalLocks.cancelRefresh = nil
-		globalLocks.Unlock()
-	}()
+// consider a lock refresh failed a bit before the lock actually becomes stale;
+// the difference allows compensating for a small time drift between clients.
+var refreshabilityTimeout = restic.StaleLockTimeout - refreshInterval*3/2
 
+func refreshLocks(ctx context.Context, lock *restic.Lock, lockInfo *lockContext, refreshed chan<- struct{}) {
+	debug.Log("start")
 	ticker := time.NewTicker(refreshInterval)
+	lastRefresh := lock.Time
+
+	defer func() {
+		ticker.Stop()
+		// ensure that the context was cancelled before removing the lock
+		lockInfo.cancel()
+
+		// remove the lock from the repo
+		debug.Log("unlocking repository with lock %v", lock)
+		if err := lock.Unlock(); err != nil {
+			debug.Log("error while unlocking: %v", err)
+			Warnf("error while unlocking: %v", err)
+		}
+
+		lockInfo.refreshWG.Done()
+	}()
 
 	for {
 		select {
-		case <-done:
+		case <-ctx.Done():
 			debug.Log("terminate")
 			return
 		case <-ticker.C:
+			if time.Since(lastRefresh) > refreshabilityTimeout {
+				// the lock is too old, wait until the expiry monitor cancels the context
+				continue
+			}
+
 			debug.Log("refreshing locks")
-			globalLocks.Lock()
-			for _, lock := range globalLocks.locks {
-				err := lock.Refresh(context.TODO())
-				if err != nil {
-					Warnf("unable to refresh lock: %v\n", err)
+			err := lock.Refresh(context.TODO())
+			if err != nil {
+				Warnf("unable to refresh lock: %v\n", err)
+			} else {
+				lastRefresh = lock.Time
+				// inform monitor goroutine about successful refresh
+				select {
+				case <-ctx.Done():
+				case refreshed <- struct{}{}:
+				}
 			}
-			globalLocks.Unlock()
+		}
+	}
+}
+
+func monitorLockRefresh(ctx context.Context, lock *restic.Lock, lockInfo *lockContext, refreshed <-chan struct{}) {
+	timer := time.NewTimer(refreshabilityTimeout)
+	defer func() {
+		timer.Stop()
+		lockInfo.cancel()
+		lockInfo.refreshWG.Done()
+	}()
+
+	for {
+		select {
+		case <-ctx.Done():
+			debug.Log("terminate expiry monitoring")
+			return
+		case <-refreshed:
+			// reset timer once the lock was refreshed successfully
+			timer.Reset(refreshabilityTimeout)
+		case <-timer.C:
+			Warnf("Fatal: failed to refresh lock in time\n")
+			return
 		}
 	}
 }
@@ -98,40 +148,35 @@ func unlockRepo(lock *restic.Lock) {
 	}
 
 	globalLocks.Lock()
-	defer globalLocks.Unlock()
+	lockInfo, exists := globalLocks.locks[lock]
+	delete(globalLocks.locks, lock)
+	globalLocks.Unlock()
 
-	for i := 0; i < len(globalLocks.locks); i++ {
-		if lock == globalLocks.locks[i] {
-			// remove the lock from the repo
-			debug.Log("unlocking repository with lock %v", lock)
-			if err := lock.Unlock(); err != nil {
-				debug.Log("error while unlocking: %v", err)
-				Warnf("error while unlocking: %v", err)
-				return
-			}
-
-			// remove the lock from the list of locks
-			globalLocks.locks = append(globalLocks.locks[:i], globalLocks.locks[i+1:]...)
-			return
-		}
+	if !exists {
+		debug.Log("unable to find lock %v in the global list of locks, ignoring", lock)
+		return
 	}
-
-	debug.Log("unable to find lock %v in the global list of locks, ignoring", lock)
+	lockInfo.cancel()
+	lockInfo.refreshWG.Wait()
 }
 
 func unlockAll(code int) (int, error) {
 	globalLocks.Lock()
-	defer globalLocks.Unlock()
-
+	locks := globalLocks.locks
 	debug.Log("unlocking %d locks", len(globalLocks.locks))
-	for _, lock := range globalLocks.locks {
-		if err := lock.Unlock(); err != nil {
-			debug.Log("error while unlocking: %v", err)
-			return code, err
-		}
-		debug.Log("successfully removed lock")
+	for _, lockInfo := range globalLocks.locks {
+		lockInfo.cancel()
+	}
+	globalLocks.locks = make(map[*restic.Lock]*lockContext)
+	globalLocks.Unlock()
+
+	for _, lockInfo := range locks {
+		lockInfo.refreshWG.Wait()
 	}
-	globalLocks.locks = globalLocks.locks[:0]
 
 	return code, nil
 }
+
+func init() {
+	globalLocks.locks = make(map[*restic.Lock]*lockContext)
+}
diff --git a/internal/restic/lock.go b/internal/restic/lock.go
index 031e8755c..c8079f58d 100644
--- a/internal/restic/lock.go
+++ b/internal/restic/lock.go
@@ -175,14 +175,14 @@ func (l *Lock) Unlock() error {
 	return l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()})
 }
 
-var staleTimeout = 30 * time.Minute
+var StaleLockTimeout = 30 * time.Minute
 
 // Stale returns true if the lock is stale. A lock is stale if the timestamp is
 // older than 30 minutes or if it was created on the current machine and the
 // process isn't alive any more.
 func (l *Lock) Stale() bool {
 	debug.Log("testing if lock %v for process %d is stale", l, l.PID)
-	if time.Since(l.Time) > staleTimeout {
+	if time.Since(l.Time) > StaleLockTimeout {
 		debug.Log("lock is stale, timestamp is too old: %v\n", l.Time)
 		return true
 	}