Ensure sensible node config on load (fixes #143)

This commit is contained in:
Jakob Borg 2014-04-22 11:46:08 +02:00
parent e0e16c371f
commit d53b193e09
3 changed files with 72 additions and 12 deletions

View File

@ -3,6 +3,7 @@ package main
import ( import (
"encoding/xml" "encoding/xml"
"io" "io"
"os"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -153,7 +154,7 @@ func uniqueStrings(ss []string) []string {
return us return us
} }
func readConfigXML(rd io.Reader) (Configuration, error) { func readConfigXML(rd io.Reader, myID string) (Configuration, error) {
var cfg Configuration var cfg Configuration
setDefaults(&cfg) setDefaults(&cfg)
@ -169,6 +170,7 @@ func readConfigXML(rd io.Reader) (Configuration, error) {
cfg.Options.ListenAddress = uniqueStrings(cfg.Options.ListenAddress) cfg.Options.ListenAddress = uniqueStrings(cfg.Options.ListenAddress)
// Check for missing or duplicate repository ID:s
var seenRepos = map[string]bool{} var seenRepos = map[string]bool{}
for i := range cfg.Repositories { for i := range cfg.Repositories {
if cfg.Repositories[i].ID == "" { if cfg.Repositories[i].ID == "" {
@ -182,10 +184,12 @@ func readConfigXML(rd io.Reader) (Configuration, error) {
seenRepos[id] = true seenRepos[id] = true
} }
// Upgrade to v2 configuration if appropriate
if cfg.Version == 1 { if cfg.Version == 1 {
convertV1V2(&cfg) convertV1V2(&cfg)
} }
// Hash old cleartext passwords
if len(cfg.GUI.Password) > 0 && cfg.GUI.Password[0] != '$' { if len(cfg.GUI.Password) > 0 && cfg.GUI.Password[0] != '$' {
hash, err := bcrypt.GenerateFromPassword([]byte(cfg.GUI.Password), 0) hash, err := bcrypt.GenerateFromPassword([]byte(cfg.GUI.Password), 0)
if err != nil { if err != nil {
@ -195,6 +199,20 @@ func readConfigXML(rd io.Reader) (Configuration, error) {
} }
} }
// Ensure this node is present in all relevant places
cfg.Nodes = ensureNodePresent(cfg.Nodes, myID)
for i := range cfg.Repositories {
cfg.Repositories[i].Nodes = ensureNodePresent(cfg.Repositories[i].Nodes, myID)
}
// An empty address list is equivalent to a single "dynamic" entry
for i := range cfg.Nodes {
n := &cfg.Nodes[i]
if len(n.Addresses) == 0 || len(n.Addresses) == 1 && n.Addresses[0] == "" {
n.Addresses = []string{"dynamic"}
}
}
return cfg, err return cfg, err
} }
@ -241,7 +259,7 @@ func (l NodeConfigurationList) Len() int {
return len(l) return len(l)
} }
func cleanNodeList(nodes []NodeConfiguration, myID string) []NodeConfiguration { func ensureNodePresent(nodes []NodeConfiguration, myID string) []NodeConfiguration {
var myIDExists bool var myIDExists bool
for _, node := range nodes { for _, node := range nodes {
if node.NodeID == myID { if node.NodeID == myID {
@ -251,10 +269,10 @@ func cleanNodeList(nodes []NodeConfiguration, myID string) []NodeConfiguration {
} }
if !myIDExists { if !myIDExists {
name, _ := os.Hostname()
nodes = append(nodes, NodeConfiguration{ nodes = append(nodes, NodeConfiguration{
NodeID: myID, NodeID: myID,
Addresses: []string{"dynamic"}, Name: name,
Name: "",
}) })
} }

View File

@ -22,7 +22,7 @@ func TestDefaultValues(t *testing.T) {
UPnPEnabled: true, UPnPEnabled: true,
} }
cfg, err := readConfigXML(bytes.NewReader(nil)) cfg, err := readConfigXML(bytes.NewReader(nil), "nodeID")
if err != io.EOF { if err != io.EOF {
t.Error(err) t.Error(err)
} }
@ -65,7 +65,7 @@ func TestNodeConfig(t *testing.T) {
`) `)
for i, data := range [][]byte{v1data, v2data} { for i, data := range [][]byte{v1data, v2data} {
cfg, err := readConfigXML(bytes.NewReader(data)) cfg, err := readConfigXML(bytes.NewReader(data), "node1")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -120,7 +120,7 @@ func TestNoListenAddress(t *testing.T) {
</configuration> </configuration>
`) `)
cfg, err := readConfigXML(bytes.NewReader(data)) cfg, err := readConfigXML(bytes.NewReader(data), "nodeID")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -169,7 +169,7 @@ func TestOverriddenValues(t *testing.T) {
UPnPEnabled: false, UPnPEnabled: false,
} }
cfg, err := readConfigXML(bytes.NewReader(data)) cfg, err := readConfigXML(bytes.NewReader(data), "nodeID")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -178,3 +178,46 @@ func TestOverriddenValues(t *testing.T) {
t.Errorf("Overridden config differs;\n E: %#v\n A: %#v", expected, cfg.Options) t.Errorf("Overridden config differs;\n E: %#v\n A: %#v", expected, cfg.Options)
} }
} }
func TestNodeAddresses(t *testing.T) {
data := []byte(`
<configuration version="2">
<node id="n1">
<address>dynamic</address>
</node>
<node id="n2">
<address></address>
</node>
<node id="n3">
</node>
</configuration>
`)
expected := []NodeConfiguration{
{
NodeID: "n1",
Addresses: []string{"dynamic"},
},
{
NodeID: "n2",
Addresses: []string{"dynamic"},
},
{
NodeID: "n3",
Addresses: []string{"dynamic"},
},
{
NodeID: "n4",
Addresses: []string{"dynamic"},
},
}
cfg, err := readConfigXML(bytes.NewReader(data), "n4")
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(cfg.Nodes, expected) {
t.Errorf("Nodes differ;\n E: %#v\n A: %#v", expected, cfg.Nodes)
}
}

View File

@ -137,7 +137,7 @@ func main() {
cf, err := os.Open(cfgFile) cf, err := os.Open(cfgFile)
if err == nil { if err == nil {
// Read config.xml // Read config.xml
cfg, err = readConfigXML(cf) cfg, err = readConfigXML(cf, myID)
if err != nil { if err != nil {
fatalln(err) fatalln(err)
} }
@ -148,7 +148,7 @@ func main() {
infoln("No config file; starting with empty defaults") infoln("No config file; starting with empty defaults")
name, _ := os.Hostname() name, _ := os.Hostname()
cfg, err = readConfigXML(nil) cfg, err = readConfigXML(nil, myID)
cfg.Repositories = []RepositoryConfiguration{ cfg.Repositories = []RepositoryConfiguration{
{ {
ID: "default", ID: "default",
@ -206,7 +206,6 @@ func main() {
m := NewModel(cfg.Options.MaxChangeKbps * 1000) m := NewModel(cfg.Options.MaxChangeKbps * 1000)
for i := range cfg.Repositories { for i := range cfg.Repositories {
cfg.Repositories[i].Nodes = cleanNodeList(cfg.Repositories[i].Nodes, myID)
dir := expandTilde(cfg.Repositories[i].Directory) dir := expandTilde(cfg.Repositories[i].Directory)
ensureDir(dir, -1) ensureDir(dir, -1)
m.AddRepo(cfg.Repositories[i].ID, dir, cfg.Repositories[i].Nodes) m.AddRepo(cfg.Repositories[i].ID, dir, cfg.Repositories[i].Nodes)