package limiter

import (
	"bytes"
	"context"
	"crypto/rand"
	"fmt"
	"io"
	"testing"

	"github.com/restic/restic/internal/backend"
	"github.com/restic/restic/internal/backend/mock"
	rtest "github.com/restic/restic/internal/test"
)

// randomBytes returns a newly allocated slice of size bytes filled from
// crypto/rand. Any error from reading the random source is handed to rtest.OK.
func randomBytes(t *testing.T, size int) []byte {
	// Mark this function as a test helper so that a failure is attributed to
	// the calling test instead of to this line.
	t.Helper()

	data := make([]byte, size)
	_, err := io.ReadFull(rand.Reader, data)
	rtest.OK(t, err)
	return data
}

// TestLimitBackendSave saves random data through a mock backend wrapped by
// LimitBackend (using a static limiter) and checks that the wrapped backend's
// SaveFn receives exactly the bytes that were passed to Save.
func TestLimitBackendSave(t *testing.T) {
	testHandle := backend.Handle{Type: backend.PackFile, Name: "test"}
	data := randomBytes(t, 1234)

	be := mock.NewBackend()
	be.SaveFn = func(ctx context.Context, h backend.Handle, rd backend.RewindReader) error {
		buf := new(bytes.Buffer)
		if _, err := io.Copy(buf, rd); err != nil {
			// Report read failures to the caller; returning nil here would
			// let the test pass without ever comparing the data.
			return err
		}
		if !bytes.Equal(data, buf.Bytes()) {
			return fmt.Errorf("data mismatch")
		}
		return nil
	}
	limiter := NewStaticLimiter(Limits{42 * 1024, 42 * 1024})
	limbe := LimitBackend(be, limiter)

	rd := backend.NewByteReader(data, nil)
	err := limbe.Save(context.TODO(), testHandle, rd)
	rtest.OK(t, err)
}

// tracedReadWriteToCloser combines an io.Reader and an io.WriterTo and
// remembers whether its own WriteTo method has been called.
type tracedReadWriteToCloser struct {
	io.Reader
	io.WriterTo
	// Traced becomes true once WriteTo has been invoked on this value.
	Traced bool
}

// newTracedReadWriteToCloser wraps rd, using it both as the embedded reader
// and as the embedded WriterTo. Traced starts out false.
func newTracedReadWriteToCloser(rd *bytes.Reader) *tracedReadWriteToCloser {
	wrapped := new(tracedReadWriteToCloser)
	wrapped.Reader = rd
	wrapped.WriterTo = rd
	return wrapped
}

// WriteTo records the call in Traced and then forwards to the embedded
// WriterTo, returning its result unchanged.
func (tr *tracedReadWriteToCloser) WriteTo(w io.Writer) (int64, error) {
	tr.Traced = true
	written, err := tr.WriterTo.WriteTo(w)
	return written, err
}

// Close does nothing and always returns nil.
func (tr *tracedReadWriteToCloser) Close() error {
	return nil
}

// TestLimitBackendLoad loads random data through a mock backend wrapped by
// LimitBackend and verifies the bytes handed to the Load callback. It runs
// all four combinations of the inner reader being returned directly or via
// newTracedReadCloser, and of the consumer copying from the reader directly
// or via newTracedReadCloser. The source's WriteTo is expected to be called
// only when both flags are set.
func TestLimitBackendLoad(t *testing.T) {
	handle := backend.Handle{Type: backend.PackFile, Name: "test"}
	want := randomBytes(t, 1234)

	type loadCase struct {
		innerWriteTo, outerWriteTo bool
	}
	cases := []loadCase{
		{false, false},
		{false, true},
		{true, false},
		{true, true},
	}

	for _, tc := range cases {
		src := newTracedReadWriteToCloser(bytes.NewReader(want))

		mockBe := mock.NewBackend()
		mockBe.OpenReaderFn = func(ctx context.Context, h backend.Handle, length int, offset int64) (io.ReadCloser, error) {
			// Only whole-file reads are served by this mock.
			if offset != 0 || length != 0 {
				return nil, fmt.Errorf("Not supported")
			}
			// Cover both code paths in WriteTo of limitedReadCloser.
			if !tc.innerWriteTo {
				return newTracedReadCloser(src), nil
			}
			return src, nil
		}

		limited := LimitBackend(mockBe, NewStaticLimiter(Limits{42 * 1024, 42 * 1024}))

		// verify drains the reader handed out by Load and compares the
		// result with the expected payload.
		verify := func(rd io.Reader) error {
			var got bytes.Buffer
			// Cover both Read and WriteTo on the consumer side.
			if !tc.outerWriteTo {
				rd = newTracedReadCloser(rd)
			}
			if _, copyErr := io.Copy(&got, rd); copyErr != nil {
				return copyErr
			}
			if !bytes.Equal(want, got.Bytes()) {
				return fmt.Errorf("read broken data")
			}
			return nil
		}

		loadErr := limited.Load(context.TODO(), handle, 0, 0, verify)
		rtest.OK(t, loadErr)

		wantTraced := tc.innerWriteTo && tc.outerWriteTo
		rtest.Assert(t, src.Traced == wantTraced,
			"unexpected/missing writeTo call innerWriteTo %v outerWriteTo %v",
			tc.innerWriteTo, tc.outerWriteTo)
	}
}