mirror of
https://github.com/octoleo/syncthing.git
synced 2024-12-22 10:58:57 +00:00
lib/syncthing: Clean up / refactor LoadOrGenerateCertificate() utility function. (#8025)
LoadOrGenerateCertificate() takes two file path arguments, but then uses the locations package to determine the actual path. Fix that with a minimally invasive change, by using the arguments instead. Factor out GenerateCertificate(). The only caller of this function is cmd/syncthing, which passes the same values, so this is technically a no-op. * lib/tlsutil: Make storing generated certificate optional. Avoid temporary cert and key files in tests, keep cert in memory.
This commit is contained in:
parent
db15e52743
commit
ec8a748514
@ -49,16 +49,13 @@ import (
|
||||
"github.com/syncthing/syncthing/lib/protocol"
|
||||
"github.com/syncthing/syncthing/lib/svcutil"
|
||||
"github.com/syncthing/syncthing/lib/syncthing"
|
||||
"github.com/syncthing/syncthing/lib/tlsutil"
|
||||
"github.com/syncthing/syncthing/lib/upgrade"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
tlsDefaultCommonName = "syncthing"
|
||||
deviceCertLifetimeDays = 20 * 365
|
||||
sigTerm = syscall.Signal(15)
|
||||
sigTerm = syscall.Signal(15)
|
||||
)
|
||||
|
||||
const (
|
||||
@ -442,7 +439,7 @@ func generate(generateDir string, noDefaultFolder bool) error {
|
||||
if err == nil {
|
||||
l.Warnln("Key exists; will not overwrite.")
|
||||
} else {
|
||||
cert, err = tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
|
||||
cert, err = syncthing.GenerateCertificate(certFile, keyFile)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create certificate")
|
||||
}
|
||||
|
@ -1209,15 +1209,9 @@ func TestPrefixMatch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestShouldRegenerateCertificate(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "syncthing-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Self signed certificates expiring in less than a month are errored so we
|
||||
// can regenerate in time.
|
||||
crt, err := tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 29)
|
||||
crt, err := tlsutil.NewCertificateInMemory("foo.example.com", 29)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1226,7 +1220,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
|
||||
}
|
||||
|
||||
// Certificates with at least 31 days of life left are fine.
|
||||
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 31)
|
||||
crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 31)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1236,7 +1230,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// Certificates with too long an expiry time are not allowed on macOS
|
||||
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 1000)
|
||||
crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 1000)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -11,11 +11,9 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -470,21 +468,9 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
|
||||
}
|
||||
|
||||
func mustGetCert(b *testing.B) tls.Certificate {
|
||||
f1, err := ioutil.TempFile("", "")
|
||||
cert, err := tlsutil.NewCertificateInMemory("bench", 10)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
f1.Close()
|
||||
f2, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
f2.Close()
|
||||
cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_ = os.Remove(f1.Name())
|
||||
_ = os.Remove(f2.Name())
|
||||
return cert
|
||||
}
|
||||
|
@ -107,13 +107,8 @@ func TestGlobalOverHTTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGlobalOverHTTPS(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "syncthing")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate a server certificate.
|
||||
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
|
||||
cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -172,13 +167,8 @@ func TestGlobalOverHTTPS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGlobalAnnounce(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "syncthing")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate a server certificate.
|
||||
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
|
||||
cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -9,7 +9,6 @@ package syncthing
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -57,13 +56,7 @@ func TestShortIDCheck(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartupFail(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir("", "syncthing-TestStartupFail-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cert, err := tlsutil.NewCertificate(filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key"), "syncthing", 365)
|
||||
cert, err := tlsutil.NewCertificateInMemory("syncthing", 365)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -25,22 +25,18 @@ import (
|
||||
)
|
||||
|
||||
func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
|
||||
cert, err := tls.LoadX509KeyPair(
|
||||
locations.Get(locations.CertFile),
|
||||
locations.Get(locations.KeyFile),
|
||||
)
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
|
||||
return tlsutil.NewCertificate(
|
||||
locations.Get(locations.CertFile),
|
||||
locations.Get(locations.KeyFile),
|
||||
tlsDefaultCommonName,
|
||||
deviceCertLifetimeDays,
|
||||
)
|
||||
return GenerateCertificate(certFile, keyFile)
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func GenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
|
||||
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
|
||||
return tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
|
||||
}
|
||||
|
||||
func DefaultConfig(path string, myID protocol.DeviceID, evLogger events.Logger, noDefaultFolder bool) (config.Wrapper, error) {
|
||||
newCfg, err := config.NewWithFreePorts(myID)
|
||||
if err != nil {
|
||||
|
@ -86,11 +86,11 @@ func SecureDefaultWithTLS12() *tls.Config {
|
||||
}
|
||||
}
|
||||
|
||||
// NewCertificate generates and returns a new TLS certificate.
|
||||
func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls.Certificate, error) {
|
||||
// generateCertificate generates a PEM formatted key pair and self-signed certificate in memory.
|
||||
func generateCertificate(commonName string, lifetimeDays int) (*pem.Block, *pem.Block, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "generate key")
|
||||
return nil, nil, errors.Wrap(err, "generate key")
|
||||
}
|
||||
|
||||
notBefore := time.Now().Truncate(24 * time.Hour)
|
||||
@ -117,19 +117,33 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "create cert")
|
||||
return nil, nil, errors.Wrap(err, "create cert")
|
||||
}
|
||||
|
||||
certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
|
||||
keyBlock, err := pemBlockForKey(priv)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "save key")
|
||||
}
|
||||
|
||||
return certBlock, keyBlock, nil
|
||||
}
|
||||
|
||||
// NewCertificate generates and returns a new TLS certificate, saved to the given PEM files.
|
||||
func NewCertificate(certFile, keyFile string, commonName string, lifetimeDays int) (tls.Certificate, error) {
|
||||
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certFile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save cert")
|
||||
}
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
if err != nil {
|
||||
if err = pem.Encode(certOut, certBlock); err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save cert")
|
||||
}
|
||||
err = certOut.Close()
|
||||
if err != nil {
|
||||
if err = certOut.Close(); err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save cert")
|
||||
}
|
||||
|
||||
@ -137,22 +151,24 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save key")
|
||||
}
|
||||
|
||||
block, err := pemBlockForKey(priv)
|
||||
if err != nil {
|
||||
if err = pem.Encode(keyOut, keyBlock); err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save key")
|
||||
}
|
||||
if err = keyOut.Close(); err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save key")
|
||||
}
|
||||
|
||||
err = pem.Encode(keyOut, block)
|
||||
return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
|
||||
}
|
||||
|
||||
// NewCertificateInMemory generates and returns a new TLS certificate, kept only in memory.
|
||||
func NewCertificateInMemory(commonName string, lifetimeDays int) (tls.Certificate, error) {
|
||||
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save key")
|
||||
}
|
||||
err = keyOut.Close()
|
||||
if err != nil {
|
||||
return tls.Certificate{}, errors.Wrap(err, "save key")
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
return tls.LoadX509KeyPair(certFile, keyFile)
|
||||
return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
|
||||
}
|
||||
|
||||
type DowngradingListener struct {
|
||||
|
Loading…
Reference in New Issue
Block a user