Rework XDR encoding

This commit is contained in:
Jakob Borg 2014-02-20 17:40:15 +01:00
parent 727f35b35b
commit 65615385e7

View File

@ -1,8 +1,12 @@
package main package main
import ( import (
"encoding/binary"
"encoding/hex"
"flag"
"log" "log"
"net" "net"
"os"
"sync" "sync"
"time" "time"
@ -10,19 +14,38 @@ import (
) )
type Node struct { type Node struct {
Addresses []Address
Updated time.Time
}
type Address struct {
IP []byte IP []byte
Port uint16 Port uint16
Updated time.Time
} }
var ( var (
nodes = make(map[string]Node) nodes = make(map[string]Node)
lock sync.Mutex lock sync.Mutex
queries = 0 queries = 0
answered = 0
) )
func main() { func main() {
addr, _ := net.ResolveUDPAddr("udp", ":22025") var debug bool
var listen string
var timestamp bool
flag.StringVar(&listen, "listen", ":22025", "Listen address")
flag.BoolVar(&debug, "debug", false, "Enable debug output")
flag.BoolVar(&timestamp, "timestamp", true, "Timestamp the log output")
flag.Parse()
log.SetOutput(os.Stdout)
if !timestamp {
log.SetFlags(0)
}
addr, _ := net.ResolveUDPAddr("udp", listen)
conn, err := net.ListenUDP("udp", addr) conn, err := net.ListenUDP("udp", addr)
if err != nil { if err != nil {
panic(err) panic(err)
@ -41,8 +64,9 @@ func main() {
deleted++ deleted++
} }
} }
log.Printf("Expired %d nodes; %d nodes in registry; %d queries", deleted, len(nodes), queries) log.Printf("Expired %d nodes; %d nodes in registry; %d queries (%d answered)", deleted, len(nodes), queries, answered)
queries = 0 queries = 0
answered = 0
lock.Unlock() lock.Unlock()
} }
@ -50,49 +74,163 @@ func main() {
var buf = make([]byte, 1024) var buf = make([]byte, 1024)
for { for {
buf = buf[:cap(buf)]
n, addr, err := conn.ReadFromUDP(buf) n, addr, err := conn.ReadFromUDP(buf)
if err != nil { if err != nil {
panic(err) panic(err)
} }
pkt, err := discover.DecodePacket(buf[:n]) if n < 4 {
if err != nil { log.Printf("Received short packet (%d bytes)", n)
log.Println("Warning:", err)
continue continue
} }
switch pkt.Magic { buf = buf[:n]
case 0x20121025: magic := binary.BigEndian.Uint32(buf)
// Announcement
lock.Lock() switch magic {
case discover.AnnouncementMagicV1:
var pkt discover.AnnounceV1
err := pkt.UnmarshalXDR(buf)
if err != nil {
log.Println("AnnounceV1 Unmarshal:", err)
log.Println(hex.Dump(buf))
continue
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
ip := addr.IP.To4() ip := addr.IP.To4()
if ip == nil { if ip == nil {
ip = addr.IP.To16() ip = addr.IP.To16()
} }
node := Node{ node := Node{
Addresses: []Address{{
IP: ip, IP: ip,
Port: uint16(pkt.Port), Port: pkt.Port,
}},
Updated: time.Now(), Updated: time.Now(),
} }
//log.Println("<-", pkt.ID, node)
nodes[pkt.ID] = node
lock.Unlock()
case 0x19760309:
// Query
lock.Lock() lock.Lock()
node, ok := nodes[pkt.ID] nodes[pkt.NodeID] = node
lock.Unlock()
case discover.QueryMagicV1:
var pkt discover.QueryV1
err := pkt.UnmarshalXDR(buf)
if err != nil {
log.Println("QueryV1 Unmarshal:", err)
log.Println(hex.Dump(buf))
continue
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
lock.Lock()
node, ok := nodes[pkt.NodeID]
queries++ queries++
lock.Unlock() lock.Unlock()
if ok {
pkt := discover.Packet{ if ok && len(node.Addresses) > 0 {
Magic: 0x20121025, pkt := discover.AnnounceV1{
ID: pkt.ID, Magic: discover.AnnouncementMagicV1,
Port: node.Port, NodeID: pkt.NodeID,
IP: node.IP, Port: node.Addresses[0].Port,
IP: node.Addresses[0].IP,
} }
_, _, err = conn.WriteMsgUDP(discover.EncodePacket(pkt), nil, addr) if debug {
log.Printf("-> %v %#v", addr, pkt)
}
tb := pkt.MarshalXDR()
_, _, err = conn.WriteMsgUDP(tb, nil, addr)
if err != nil { if err != nil {
log.Println("Warning:", err) log.Println("QueryV1 response write:", err)
} }
lock.Lock()
answered++
lock.Unlock()
}
case discover.AnnouncementMagicV2:
var pkt discover.AnnounceV2
err := pkt.UnmarshalXDR(buf)
if err != nil {
log.Println("AnnounceV2 Unmarshal:", err)
log.Println(hex.Dump(buf))
continue
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
ip := addr.IP.To4()
if ip == nil {
ip = addr.IP.To16()
}
var addrs []Address
for _, addr := range pkt.Addresses {
tip := addr.IP
if len(tip) == 0 {
tip = ip
}
addrs = append(addrs, Address{
IP: tip,
Port: addr.Port,
})
}
node := Node{
Addresses: addrs,
Updated: time.Now(),
}
lock.Lock()
nodes[pkt.NodeID] = node
lock.Unlock()
case discover.QueryMagicV2:
var pkt discover.QueryV2
err := pkt.UnmarshalXDR(buf)
if err != nil {
log.Println("QueryV2 Unmarshal:", err)
log.Println(hex.Dump(buf))
continue
}
if debug {
log.Printf("<- %v %#v", addr, pkt)
}
lock.Lock()
node, ok := nodes[pkt.NodeID]
queries++
lock.Unlock()
if ok && len(node.Addresses) > 0 {
pkt := discover.AnnounceV2{
Magic: discover.AnnouncementMagicV2,
NodeID: pkt.NodeID,
}
for _, addr := range node.Addresses {
pkt.Addresses = append(pkt.Addresses, discover.Address{IP: addr.IP, Port: addr.Port})
}
if debug {
log.Printf("-> %v %#v", addr, pkt)
}
tb := pkt.MarshalXDR()
_, _, err = conn.WriteMsgUDP(tb, nil, addr)
if err != nil {
log.Println("QueryV2 response write:", err)
}
lock.Lock()
answered++
lock.Unlock()
} }
} }
} }