diff --git a/lib/db/backend/backend_test.go b/lib/db/backend/backend_test.go index c77b5792d..d755e0cb6 100644 --- a/lib/db/backend/backend_test.go +++ b/lib/db/backend/backend_test.go @@ -14,6 +14,7 @@ import "testing" func testBackendBehavior(t *testing.T, open func() Backend) { t.Run("WriteIsolation", func(t *testing.T) { testWriteIsolation(t, open) }) t.Run("DeleteNonexisten", func(t *testing.T) { testDeleteNonexistent(t, open) }) + t.Run("IteratorClosedDB", func(t *testing.T) { testIteratorClosedDB(t, open) }) } func testWriteIsolation(t *testing.T, open func() Backend) { @@ -51,3 +52,25 @@ func testDeleteNonexistent(t *testing.T, open func() Backend) { t.Error(err) } } + +// Either creating the iterator or the .Error() method of the returned iterator +// should return an error and IsClosed(err) == true. +func testIteratorClosedDB(t *testing.T, open func() Backend) { + db := open() + + _ = db.Put([]byte("a"), []byte("a")) + + db.Close() + + it, err := db.NewPrefixIterator(nil) + if err != nil { + if !IsClosed(err) { + t.Error("NewPrefixIterator: IsClosed(err) == false:", err) + } + return + } + it.Next() + if err := it.Error(); !IsClosed(err) { + t.Error("Next: IsClosed(err) == false:", err) + } +} diff --git a/lib/db/backend/leveldb_backend.go b/lib/db/backend/leveldb_backend.go index e142ca0ba..7ac5c88a0 100644 --- a/lib/db/backend/leveldb_backend.go +++ b/lib/db/backend/leveldb_backend.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/util" ) @@ -65,11 +66,11 @@ func (b *leveldbBackend) Get(key []byte) ([]byte, error) { } func (b *leveldbBackend) NewPrefixIterator(prefix []byte) (Iterator, error) { - return b.ldb.NewIterator(util.BytesPrefix(prefix), nil), nil + return &leveldbIterator{b.ldb.NewIterator(util.BytesPrefix(prefix), nil)}, nil } func (b *leveldbBackend) NewRangeIterator(first, last []byte) (Iterator, error) { - return b.ldb.NewIterator(&util.Range{Start: first, Limit: last}, nil), nil + return &leveldbIterator{b.ldb.NewIterator(&util.Range{Start: first, Limit: last}, nil)}, nil } func (b *leveldbBackend) Put(key, val []byte) error { @@ -158,6 +159,14 @@ func (t *leveldbTransaction) flush() error { return nil } +type leveldbIterator struct { + iterator.Iterator +} + +func (it *leveldbIterator) Error() error { + return wrapLeveldbErr(it.Iterator.Error()) +} + // wrapLeveldbErr wraps errors so that the backend package can recognize them func wrapLeveldbErr(err error) error { if err == nil {