mirror of
https://github.com/octoleo/syncthing.git
synced 2024-11-06 05:17:49 +00:00
327 lines
8.0 KiB
Go
327 lines
8.0 KiB
Go
// Copyright (C) 2018 The Syncthing Authors.
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
// You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
|
|
package main
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/binary"
|
|
"fmt"
|
|
io "io"
|
|
"log"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/syncthing/syncthing/lib/protocol"
|
|
)
|
|
|
|
// Timing parameters for replication connections.
const (
	// replicationReadTimeout is the read deadline the listener applies
	// before waiting for each incoming record.
	replicationReadTimeout = time.Minute

	// replicationHeartbeatInterval is the period of the sender's heartbeat
	// ticker. It is shorter than the read timeout so that an idle but
	// healthy connection does not hit the read deadline.
	replicationHeartbeatInterval = 30 * time.Second
)
|
|
|
|
type replicator interface {
|
|
send(key string, addrs []DatabaseAddress, seen int64)
|
|
}
|
|
|
|
// a replicationSender tries to connect to the remote address and provide
|
|
// them with a feed of replication updates.
|
|
type replicationSender struct {
|
|
dst string
|
|
cert tls.Certificate // our certificate
|
|
allowedIDs []protocol.DeviceID
|
|
outbox chan ReplicationRecord
|
|
stop chan struct{}
|
|
}
|
|
|
|
func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
|
|
return &replicationSender{
|
|
dst: dst,
|
|
cert: cert,
|
|
allowedIDs: allowedIDs,
|
|
outbox: make(chan ReplicationRecord, replicationOutboxSize),
|
|
stop: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (s *replicationSender) Serve() {
|
|
// Sleep a little at startup. Peers often restart at the same time, and
|
|
// this avoid the service failing and entering backoff state
|
|
// unnecessarily, while also reducing the reconnect rate to something
|
|
// reasonable by default.
|
|
time.Sleep(2 * time.Second)
|
|
|
|
tlsCfg := &tls.Config{
|
|
Certificates: []tls.Certificate{s.cert},
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: true,
|
|
}
|
|
|
|
// Dial the TLS connection.
|
|
conn, err := tls.Dial("tcp", s.dst, tlsCfg)
|
|
if err != nil {
|
|
log.Println("Replication connect:", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
|
conn.Close()
|
|
}()
|
|
|
|
// Get the other side device ID.
|
|
remoteID, err := deviceID(conn)
|
|
if err != nil {
|
|
log.Println("Replication connect:", err)
|
|
return
|
|
}
|
|
|
|
// Verify it's in the set of allowed device IDs.
|
|
if !deviceIDIn(remoteID, s.allowedIDs) {
|
|
log.Println("Replication connect: unexpected device ID:", remoteID)
|
|
return
|
|
}
|
|
|
|
heartBeatTicker := time.NewTicker(replicationHeartbeatInterval)
|
|
defer heartBeatTicker.Stop()
|
|
|
|
// Send records.
|
|
buf := make([]byte, 1024)
|
|
for {
|
|
select {
|
|
case <-heartBeatTicker.C:
|
|
if len(s.outbox) > 0 {
|
|
// No need to send heartbeats if there are events/prevrious
|
|
// heartbeats to send, they will keep the connection alive.
|
|
continue
|
|
}
|
|
// Empty replication message is the heartbeat:
|
|
s.outbox <- ReplicationRecord{}
|
|
|
|
case rec := <-s.outbox:
|
|
// Buffer must hold record plus four bytes for size
|
|
size := rec.Size()
|
|
if len(buf) < size+4 {
|
|
buf = make([]byte, size+4)
|
|
}
|
|
|
|
// Record comes after the four bytes size
|
|
n, err := rec.MarshalTo(buf[4:])
|
|
if err != nil {
|
|
// odd to get an error here, but we haven't sent anything
|
|
// yet so it's not fatal
|
|
replicationSendsTotal.WithLabelValues("error").Inc()
|
|
log.Println("Replication marshal:", err)
|
|
continue
|
|
}
|
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
|
|
|
// Send
|
|
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
if _, err := conn.Write(buf[:4+n]); err != nil {
|
|
replicationSendsTotal.WithLabelValues("error").Inc()
|
|
log.Println("Replication write:", err)
|
|
// Yes, we are loosing the replication event here.
|
|
return
|
|
}
|
|
replicationSendsTotal.WithLabelValues("success").Inc()
|
|
|
|
case <-s.stop:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *replicationSender) Stop() {
|
|
close(s.stop)
|
|
}
|
|
|
|
func (s *replicationSender) String() string {
|
|
return fmt.Sprintf("replicationSender(%q)", s.dst)
|
|
}
|
|
|
|
func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) {
|
|
item := ReplicationRecord{
|
|
Key: key,
|
|
Addresses: ps,
|
|
}
|
|
|
|
// The send should never block. The inbox is suitably buffered for at
|
|
// least a few seconds of stalls, which shouldn't happen in practice.
|
|
select {
|
|
case s.outbox <- item:
|
|
default:
|
|
replicationSendsTotal.WithLabelValues("drop").Inc()
|
|
}
|
|
}
|
|
|
|
// a replicationMultiplexer sends to multiple replicators
|
|
type replicationMultiplexer []replicator
|
|
|
|
func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
|
|
for _, s := range m {
|
|
// each send is nonblocking
|
|
s.send(key, ps, seen)
|
|
}
|
|
}
|
|
|
|
// replicationListener accepts incoming connections and reads replication
|
|
// items from them. Incoming items are applied to the KV store.
|
|
type replicationListener struct {
|
|
addr string
|
|
cert tls.Certificate
|
|
allowedIDs []protocol.DeviceID
|
|
db database
|
|
stop chan struct{}
|
|
}
|
|
|
|
func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
|
|
return &replicationListener{
|
|
addr: addr,
|
|
cert: cert,
|
|
allowedIDs: allowedIDs,
|
|
db: db,
|
|
stop: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (l *replicationListener) Serve() {
|
|
tlsCfg := &tls.Config{
|
|
Certificates: []tls.Certificate{l.cert},
|
|
ClientAuth: tls.RequestClientCert,
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: true,
|
|
}
|
|
|
|
lst, err := tls.Listen("tcp", l.addr, tlsCfg)
|
|
if err != nil {
|
|
log.Println("Replication listen:", err)
|
|
return
|
|
}
|
|
defer lst.Close()
|
|
|
|
for {
|
|
select {
|
|
case <-l.stop:
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Accept a connection
|
|
conn, err := lst.Accept()
|
|
if err != nil {
|
|
log.Println("Replication accept:", err)
|
|
return
|
|
}
|
|
|
|
// Figure out the other side device ID
|
|
remoteID, err := deviceID(conn.(*tls.Conn))
|
|
if err != nil {
|
|
log.Println("Replication accept:", err)
|
|
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
|
conn.Close()
|
|
continue
|
|
}
|
|
|
|
// Verify it is in the set of allowed device IDs
|
|
if !deviceIDIn(remoteID, l.allowedIDs) {
|
|
log.Println("Replication accept: unexpected device ID:", remoteID)
|
|
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
|
conn.Close()
|
|
continue
|
|
}
|
|
|
|
go l.handle(conn)
|
|
}
|
|
}
|
|
|
|
func (l *replicationListener) Stop() {
|
|
close(l.stop)
|
|
}
|
|
|
|
func (l *replicationListener) String() string {
|
|
return fmt.Sprintf("replicationListener(%q)", l.addr)
|
|
}
|
|
|
|
func (l *replicationListener) handle(conn net.Conn) {
|
|
defer func() {
|
|
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
|
conn.Close()
|
|
}()
|
|
|
|
buf := make([]byte, 1024)
|
|
|
|
for {
|
|
select {
|
|
case <-l.stop:
|
|
return
|
|
default:
|
|
}
|
|
|
|
conn.SetReadDeadline(time.Now().Add(replicationReadTimeout))
|
|
|
|
// First four bytes are the size
|
|
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
|
log.Println("Replication read size:", err)
|
|
replicationRecvsTotal.WithLabelValues("error").Inc()
|
|
return
|
|
}
|
|
|
|
// Read the rest of the record
|
|
size := int(binary.BigEndian.Uint32(buf[:4]))
|
|
if len(buf) < size {
|
|
buf = make([]byte, size)
|
|
}
|
|
|
|
if size == 0 {
|
|
// Heartbeat, ignore
|
|
continue
|
|
}
|
|
|
|
if _, err := io.ReadFull(conn, buf[:size]); err != nil {
|
|
log.Println("Replication read record:", err)
|
|
replicationRecvsTotal.WithLabelValues("error").Inc()
|
|
return
|
|
}
|
|
|
|
// Unmarshal
|
|
var rec ReplicationRecord
|
|
if err := rec.Unmarshal(buf[:size]); err != nil {
|
|
log.Println("Replication unmarshal:", err)
|
|
replicationRecvsTotal.WithLabelValues("error").Inc()
|
|
continue
|
|
}
|
|
|
|
// Store
|
|
l.db.merge(rec.Key, rec.Addresses, rec.Seen)
|
|
replicationRecvsTotal.WithLabelValues("success").Inc()
|
|
}
|
|
}
|
|
|
|
func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
|
|
// Handshake may not be complete on the server side yet, which we need
|
|
// to get the client certificate.
|
|
if !conn.ConnectionState().HandshakeComplete {
|
|
if err := conn.Handshake(); err != nil {
|
|
return protocol.DeviceID{}, err
|
|
}
|
|
}
|
|
|
|
// We expect exactly one certificate.
|
|
certs := conn.ConnectionState().PeerCertificates
|
|
if len(certs) != 1 {
|
|
return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
|
|
}
|
|
|
|
return protocol.NewDeviceID(certs[0].Raw), nil
|
|
}
|
|
|
|
func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
|
|
for _, candidate := range ids {
|
|
if id == candidate {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|