diff --git a/cmd/discosrv/main.go b/cmd/discosrv/main.go index b1d3492c1..1ac8ae170 100644 --- a/cmd/discosrv/main.go +++ b/cmd/discosrv/main.go @@ -1,8 +1,12 @@ package main import ( + "encoding/binary" + "encoding/hex" + "flag" "log" "net" + "os" "sync" "time" @@ -10,19 +14,38 @@ import ( ) type Node struct { - IP []byte - Port uint16 - Updated time.Time + Addresses []Address + Updated time.Time +} + +type Address struct { + IP []byte + Port uint16 } var ( - nodes = make(map[string]Node) - lock sync.Mutex - queries = 0 + nodes = make(map[string]Node) + lock sync.Mutex + queries = 0 + answered = 0 ) func main() { - addr, _ := net.ResolveUDPAddr("udp", ":22025") + var debug bool + var listen string + var timestamp bool + + flag.StringVar(&listen, "listen", ":22025", "Listen address") + flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.BoolVar(×tamp, "timestamp", true, "Timestamp the log output") + flag.Parse() + + log.SetOutput(os.Stdout) + if !timestamp { + log.SetFlags(0) + } + + addr, _ := net.ResolveUDPAddr("udp", listen) conn, err := net.ListenUDP("udp", addr) if err != nil { panic(err) @@ -41,8 +64,9 @@ func main() { deleted++ } } - log.Printf("Expired %d nodes; %d nodes in registry; %d queries", deleted, len(nodes), queries) + log.Printf("Expired %d nodes; %d nodes in registry; %d queries (%d answered)", deleted, len(nodes), queries, answered) queries = 0 + answered = 0 lock.Unlock() } @@ -50,49 +74,163 @@ func main() { var buf = make([]byte, 1024) for { + buf = buf[:cap(buf)] n, addr, err := conn.ReadFromUDP(buf) if err != nil { panic(err) } - pkt, err := discover.DecodePacket(buf[:n]) - if err != nil { - log.Println("Warning:", err) + if n < 4 { + log.Printf("Received short packet (%d bytes)", n) continue } - switch pkt.Magic { - case 0x20121025: - // Announcement - lock.Lock() + buf = buf[:n] + magic := binary.BigEndian.Uint32(buf) + + switch magic { + case discover.AnnouncementMagicV1: + var pkt discover.AnnounceV1 + err := pkt.UnmarshalXDR(buf) + if err != nil { + log.Println("AnnounceV1 Unmarshal:", err) + log.Println(hex.Dump(buf)) + continue + } + if debug { + log.Printf("<- %v %#v", addr, pkt) + } + ip := addr.IP.To4() if ip == nil { ip = addr.IP.To16() } node := Node{ - IP: ip, - Port: uint16(pkt.Port), + Addresses: []Address{{ + IP: ip, + Port: pkt.Port, + }}, Updated: time.Now(), } - //log.Println("<-", pkt.ID, node) - nodes[pkt.ID] = node - lock.Unlock() - case 0x19760309: - // Query + lock.Lock() - node, ok := nodes[pkt.ID] + nodes[pkt.NodeID] = node + lock.Unlock() + + case discover.QueryMagicV1: + var pkt discover.QueryV1 + err := pkt.UnmarshalXDR(buf) + if err != nil { + log.Println("QueryV1 Unmarshal:", err) + log.Println(hex.Dump(buf)) + continue + } + if debug { + log.Printf("<- %v %#v", addr, pkt) + } + + lock.Lock() + node, ok := nodes[pkt.NodeID] queries++ lock.Unlock() - if ok { - pkt := discover.Packet{ - Magic: 0x20121025, - ID: pkt.ID, - Port: node.Port, - IP: node.IP, + + if ok && len(node.Addresses) > 0 { + pkt := discover.AnnounceV1{ + Magic: discover.AnnouncementMagicV1, + NodeID: pkt.NodeID, + Port: node.Addresses[0].Port, + IP: node.Addresses[0].IP, } - _, _, err = conn.WriteMsgUDP(discover.EncodePacket(pkt), nil, addr) + if debug { + log.Printf("-> %v %#v", addr, pkt) + } + + tb := pkt.MarshalXDR() + _, _, err = conn.WriteMsgUDP(tb, nil, addr) if err != nil { - log.Println("Warning:", err) + log.Println("QueryV1 response write:", err) } + + lock.Lock() + answered++ + lock.Unlock() + } + + case discover.AnnouncementMagicV2: + var pkt discover.AnnounceV2 + err := pkt.UnmarshalXDR(buf) + if err != nil { + log.Println("AnnounceV2 Unmarshal:", err) + log.Println(hex.Dump(buf)) + continue + } + if debug { + log.Printf("<- %v %#v", addr, pkt) + } + + ip := addr.IP.To4() + if ip == nil { + ip = addr.IP.To16() + } + + var addrs []Address + for _, addr := range pkt.Addresses { + tip := addr.IP + if len(tip) == 0 { + tip = ip + } + addrs = append(addrs, Address{ + IP: tip, + Port: addr.Port, + }) + } + + node := Node{ + Addresses: addrs, + Updated: time.Now(), + } + + lock.Lock() + nodes[pkt.NodeID] = node + lock.Unlock() + + case discover.QueryMagicV2: + var pkt discover.QueryV2 + err := pkt.UnmarshalXDR(buf) + if err != nil { + log.Println("QueryV2 Unmarshal:", err) + log.Println(hex.Dump(buf)) + continue + } + if debug { + log.Printf("<- %v %#v", addr, pkt) + } + + lock.Lock() + node, ok := nodes[pkt.NodeID] + queries++ + lock.Unlock() + + if ok && len(node.Addresses) > 0 { + pkt := discover.AnnounceV2{ + Magic: discover.AnnouncementMagicV2, + NodeID: pkt.NodeID, + } + for _, addr := range node.Addresses { + pkt.Addresses = append(pkt.Addresses, discover.Address{IP: addr.IP, Port: addr.Port}) + } + if debug { + log.Printf("-> %v %#v", addr, pkt) + } + + tb := pkt.MarshalXDR() + _, _, err = conn.WriteMsgUDP(tb, nil, addr) + if err != nil { + log.Println("QueryV2 response write:", err) + } + + lock.Lock() + answered++ + lock.Unlock() } } }