lib/syncthing: Clean up / refactor LoadOrGenerateCertificate() utility function. (#8025)

LoadOrGenerateCertificate() takes two file path arguments, but then
uses the locations package to determine the actual path.  Fix that
with a minimally invasive change, by using the arguments instead.
Factor out GenerateCertificate().

The only caller of this function is cmd/syncthing, which passes the
same values, so this is technically a no-op.

* lib/tlsutil: Make storing generated certificate optional.  Avoid
  temporary cert and key files in tests, keep cert in memory.
This commit is contained in:
André Colomb 2021-11-07 23:59:48 +01:00 committed by GitHub
parent db15e52743
commit ec8a748514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 78 deletions

View File

@ -49,15 +49,12 @@ import (
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/svcutil"
"github.com/syncthing/syncthing/lib/syncthing"
"github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/syncthing/lib/upgrade"
"github.com/pkg/errors"
)
const (
tlsDefaultCommonName = "syncthing"
deviceCertLifetimeDays = 20 * 365
sigTerm = syscall.Signal(15)
)
@ -442,7 +439,7 @@ func generate(generateDir string, noDefaultFolder bool) error {
if err == nil {
l.Warnln("Key exists; will not overwrite.")
} else {
cert, err = tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
cert, err = syncthing.GenerateCertificate(certFile, keyFile)
if err != nil {
return errors.Wrap(err, "create certificate")
}

View File

@ -1209,15 +1209,9 @@ func TestPrefixMatch(t *testing.T) {
}
func TestShouldRegenerateCertificate(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
// Self signed certificates expiring in less than a month are errored so we
// can regenerate in time.
crt, err := tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 29)
crt, err := tlsutil.NewCertificateInMemory("foo.example.com", 29)
if err != nil {
t.Fatal(err)
}
@ -1226,7 +1220,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
}
// Certificates with at least 31 days of life left are fine.
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 31)
crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 31)
if err != nil {
t.Fatal(err)
}
@ -1236,7 +1230,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
if runtime.GOOS == "darwin" {
// Certificates with too long an expiry time are not allowed on macOS
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 1000)
crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 1000)
if err != nil {
t.Fatal(err)
}

View File

@ -11,11 +11,9 @@ import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/url"
"os"
"strings"
"testing"
"time"
@ -470,21 +468,9 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
}
func mustGetCert(b *testing.B) tls.Certificate {
f1, err := ioutil.TempFile("", "")
cert, err := tlsutil.NewCertificateInMemory("bench", 10)
if err != nil {
b.Fatal(err)
}
f1.Close()
f2, err := ioutil.TempFile("", "")
if err != nil {
b.Fatal(err)
}
f2.Close()
cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10)
if err != nil {
b.Fatal(err)
}
_ = os.Remove(f1.Name())
_ = os.Remove(f2.Name())
return cert
}

View File

@ -107,13 +107,8 @@ func TestGlobalOverHTTP(t *testing.T) {
}
func TestGlobalOverHTTPS(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing")
if err != nil {
t.Fatal(err)
}
// Generate a server certificate.
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
if err != nil {
t.Fatal(err)
}
@ -172,13 +167,8 @@ func TestGlobalOverHTTPS(t *testing.T) {
}
func TestGlobalAnnounce(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing")
if err != nil {
t.Fatal(err)
}
// Generate a server certificate.
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
if err != nil {
t.Fatal(err)
}

View File

@ -9,7 +9,6 @@ package syncthing
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
@ -57,13 +56,7 @@ func TestShortIDCheck(t *testing.T) {
}
func TestStartupFail(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "syncthing-TestStartupFail-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
cert, err := tlsutil.NewCertificate(filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key"), "syncthing", 365)
cert, err := tlsutil.NewCertificateInMemory("syncthing", 365)
if err != nil {
t.Fatal(err)
}

View File

@ -25,22 +25,18 @@ import (
)
func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(
locations.Get(locations.CertFile),
locations.Get(locations.KeyFile),
)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
return tlsutil.NewCertificate(
locations.Get(locations.CertFile),
locations.Get(locations.KeyFile),
tlsDefaultCommonName,
deviceCertLifetimeDays,
)
return GenerateCertificate(certFile, keyFile)
}
return cert, nil
}
func GenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
return tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
}
func DefaultConfig(path string, myID protocol.DeviceID, evLogger events.Logger, noDefaultFolder bool) (config.Wrapper, error) {
newCfg, err := config.NewWithFreePorts(myID)
if err != nil {

View File

@ -86,11 +86,11 @@ func SecureDefaultWithTLS12() *tls.Config {
}
}
// NewCertificate generates and returns a new TLS certificate.
func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls.Certificate, error) {
// generateCertificate generates a PEM formatted key pair and self-signed certificate in memory.
func generateCertificate(commonName string, lifetimeDays int) (*pem.Block, *pem.Block, error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "generate key")
return nil, nil, errors.Wrap(err, "generate key")
}
notBefore := time.Now().Truncate(24 * time.Hour)
@ -117,19 +117,33 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "create cert")
return nil, nil, errors.Wrap(err, "create cert")
}
certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
keyBlock, err := pemBlockForKey(priv)
if err != nil {
return nil, nil, errors.Wrap(err, "save key")
}
return certBlock, keyBlock, nil
}
// NewCertificate generates and returns a new TLS certificate, saved to the given PEM files.
func NewCertificate(certFile, keyFile string, commonName string, lifetimeDays int) (tls.Certificate, error) {
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
if err != nil {
return tls.Certificate{}, err
}
certOut, err := os.Create(certFile)
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert")
}
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if err != nil {
if err = pem.Encode(certOut, certBlock); err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert")
}
err = certOut.Close()
if err != nil {
if err = certOut.Close(); err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert")
}
@ -137,22 +151,24 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
}
block, err := pemBlockForKey(priv)
if err != nil {
if err = pem.Encode(keyOut, keyBlock); err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
}
if err = keyOut.Close(); err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
}
err = pem.Encode(keyOut, block)
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
}
err = keyOut.Close()
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
}
return tls.LoadX509KeyPair(certFile, keyFile)
// NewCertificateInMemory generates and returns a new TLS certificate, kept only in memory.
func NewCertificateInMemory(commonName string, lifetimeDays int) (tls.Certificate, error) {
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
if err != nil {
return tls.Certificate{}, err
}
return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
}
type DowngradingListener struct {