package sftp

import (
	"encoding"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// _testSender is a sender stub that forwards every packet passed to
// sendPacket onto its sent channel, so tests can observe send order.
type _testSender struct {
	sent chan encoding.BinaryMarshaler
}

// newTestSender returns a _testSender with an unbuffered sent channel: each
// sendPacket call blocks until the test receives the packet.
func newTestSender() *_testSender {
	return &_testSender{make(chan encoding.BinaryMarshaler)}
}

// sendPacket delivers p on s.sent and always reports success.
func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
	s.sent <- p
	return nil
}

// fakepacket is a minimal packet whose only meaningful property is its ID
// (the numeric value itself).
type fakepacket uint32

// MarshalBinary returns an empty payload; these tests only look at IDs.
func (fakepacket) MarshalBinary() ([]byte, error) {
	return []byte{}, nil
}

// UnmarshalBinary is a no-op.
func (fakepacket) UnmarshalBinary([]byte) error {
	return nil
}

// id returns the packet ID, which is the fakepacket value itself.
func (f fakepacket) id() uint32 {
	return uint32(f)
}

// pair is one step of a packet-manager scenario: in is fed to
// incomingPacket, and out is later passed to readyPacket.
type pair struct {
	in  fakepacket
	out fakepacket
}

// basic test
var ttable1 = []pair{
	{fakepacket(0), fakepacket(0)},
	{fakepacket(1), fakepacket(1)},
	{fakepacket(2), fakepacket(2)},
	{fakepacket(3), fakepacket(3)},
}

// outgoing packets out of order
var ttable2 = []pair{
	{fakepacket(0), fakepacket(0)},
	{fakepacket(1), fakepacket(4)},
	{fakepacket(2), fakepacket(1)},
	{fakepacket(3), fakepacket(3)},
	{fakepacket(4), fakepacket(2)},
}

// incoming packets out of order
var ttable3 = []pair{
	{fakepacket(2), fakepacket(0)},
	{fakepacket(1), fakepacket(1)},
	{fakepacket(3), fakepacket(2)},
	{fakepacket(0), fakepacket(3)},
}

var tables = [][]pair{ttable1, ttable2, ttable3}

// TestPacketManager runs each table through a single packet manager. It
// registers every in packet via incomingPacket, marks every out packet ready
// via readyPacket (possibly in a different order), then checks that the
// sender emits packets whose IDs are 0, 1, 2, ... for that table.
func TestPacketManager(t *testing.T) {
	sender := newTestSender()
	s := newPktMgr(sender)
	defer s.close()

	for _, table := range tables {
		for _, p := range table {
			s.incomingPacket(p.in)
		}
		for _, p := range table {
			s.readyPacket(p.out)
		}
		for want := 0; want < len(table); want++ {
			pkt := <-sender.sent
			id := pkt.(fakepacket).id()
			// testify's assert.Equal takes (expected, actual); keeping that
			// order makes failure messages label the values correctly.
			assert.Equal(t, uint32(want), id)
		}
	}
}

// The String methods below give the request packets compact %v output, so
// failure messages that print packet slices show short type tags and IDs.

func (p sshFxpRemovePacket) String() string {
	return fmt.Sprintf("RmPct:%d", p.ID)
}

func (p sshFxpOpenPacket) String() string {
	return fmt.Sprintf("OpPct:%d", p.ID)
}

func (p sshFxpWritePacket) String() string {
	return fmt.Sprintf("WrPct:%d", p.ID)
}

func (p sshFxpClosePacket) String() string {
	return fmt.Sprintf("ClPct:%d", p.ID)
}
happens when the pool processes a close packet on a file that it // is still reading from. func TestCloseOutOfOrder(t *testing.T) { packets := []requestPacket{ &sshFxpRemovePacket{ID: 0, Filename: "foo"}, &sshFxpOpenPacket{ID: 1}, &sshFxpWritePacket{ID: 2, Handle: "foo"}, &sshFxpWritePacket{ID: 3, Handle: "foo"}, &sshFxpWritePacket{ID: 4, Handle: "foo"}, &sshFxpWritePacket{ID: 5, Handle: "foo"}, &sshFxpClosePacket{ID: 6, Handle: "foo"}, &sshFxpRemovePacket{ID: 7, Filename: "foo"}, } recvChan := make(chan requestPacket, len(packets)+1) sender := newTestSender() pktMgr := newPktMgr(sender) wg := sync.WaitGroup{} wg.Add(len(packets)) runWorker := func(ch requestChan) { go func() { for pkt := range ch { if _, ok := pkt.(*sshFxpWritePacket); ok { // sleep to cause writes to come after close/remove time.Sleep(time.Millisecond) } pktMgr.working.Done() recvChan <- pkt wg.Done() } }() } pktChan := pktMgr.workerChan(runWorker) for _, p := range packets { pktChan <- p } wg.Wait() close(recvChan) received := []requestPacket{} for p := range recvChan { received = append(received, p) } if received[len(received)-2].id() != packets[len(packets)-2].id() { t.Fatal("Packets processed out of order1:", received, packets) } if received[len(received)-1].id() != packets[len(packets)-1].id() { t.Fatal("Packets processed out of order2:", received, packets) } }