diff --git a/go/base/context.go b/go/base/context.go index 0043d07..856a29c 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -99,6 +99,8 @@ type MigrationContext struct { ConfigFile string CliUser string CliPassword string + UseTLS bool + TlsCACertificate string CliMasterUser string CliMasterPassword string @@ -695,6 +697,13 @@ func (this *MigrationContext) ApplyCredentials() { } } +func (this *MigrationContext) SetupTLS() error { + if this.UseTLS { + return this.InspectorConnectionConfig.UseTLS(this.TlsCACertificate) + } + return nil +} + // ReadConfigFile attempts to read the config file, if it exists func (this *MigrationContext) ReadConfigFile() error { this.configMutex.Lock() diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 794d545..41d1ead 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -46,6 +46,7 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) (binlogReader *Go Port: uint16(binlogReader.connectionConfig.Key.Port), User: binlogReader.connectionConfig.User, Password: binlogReader.connectionConfig.Password, + TLSConfig: binlogReader.connectionConfig.TLSConfig(), UseDecimal: true, } binlogReader.binlogSyncer = replication.NewBinlogSyncer(binlogSyncerConfig) diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 3024702..8bd8c78 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -14,6 +14,7 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/logic" + _ "github.com/go-sql-driver/mysql" "github.com/outbrain/golib/log" "golang.org/x/crypto/ssh/terminal" @@ -54,6 +55,9 @@ func main() { flag.StringVar(&migrationContext.ConfigFile, "conf", "", "Config file") askPass := flag.Bool("ask-pass", false, "prompt for MySQL password") + flag.BoolVar(&migrationContext.UseTLS, "ssl", false, "Enable SSL encrypted connections to MySQL") + flag.StringVar(&migrationContext.TlsCACertificate, "ssl-ca", "", "CA certificate in PEM format for TLS connections. Requires --ssl") + flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") @@ -194,6 +198,9 @@ func main() { if migrationContext.CliMasterPassword != "" && migrationContext.AssumeMasterHostname == "" { log.Fatalf("--master-password requires --assume-master-host") } + if migrationContext.TlsCACertificate != "" && !migrationContext.UseTLS { + log.Fatalf("--ssl-ca requires --ssl") + } if *replicationLagQuery != "" { log.Warningf("--replication-lag-query is deprecated") } @@ -238,6 +245,7 @@ func main() { migrationContext.SetThrottleHTTP(*throttleHTTP) migrationContext.SetDefaultNumRetries(*defaultRetries) migrationContext.ApplyCredentials() + migrationContext.SetupTLS() if err := migrationContext.SetCutOverLockTimeoutSeconds(*cutOverLockTimeoutSeconds); err != nil { log.Errore(err) } diff --git a/go/logic/applier.go b/go/logic/applier.go index 140b1a4..b3de0e1 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -73,7 +73,7 @@ func (this *Applier) InitDBConnections() (err error) { if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { return err } - singletonApplierUri := fmt.Sprintf("%s?timeout=0", applierUri) + singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil { return err } diff --git a/go/mysql/connection.go b/go/mysql/connection.go index c9c75f2..bfb445a 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -6,8 +6,14 @@ package mysql import ( + "crypto/tls" + "crypto/x509" + "errors" "fmt" + "io/ioutil" "net" + + "github.com/go-sql-driver/mysql" ) // ConnectionConfig is the minimal configuration required to connect to a MySQL server @@ -16,6 +22,7 @@ type ConnectionConfig struct { User string Password string ImpliedKey *InstanceKey + tlsConfig *tls.Config } func NewConnectionConfig() *ConnectionConfig { @@ -29,9 +36,10 @@ func NewConnectionConfig() *ConnectionConfig { // DuplicateCredentials creates a new connection config with given key and with same credentials as this config func (this *ConnectionConfig) DuplicateCredentials(key InstanceKey) *ConnectionConfig { config := &ConnectionConfig{ - Key: key, - User: this.User, - Password: this.Password, + Key: key, + User: this.User, + Password: this.Password, + tlsConfig: this.tlsConfig, } config.ImpliedKey = &config.Key return config @@ -42,13 +50,42 @@ func (this *ConnectionConfig) Duplicate() *ConnectionConfig { } func (this *ConnectionConfig) String() string { - return fmt.Sprintf("%s, user=%s", this.Key.DisplayString(), this.User) + return fmt.Sprintf("%s, user=%s, usingTLS=%t", this.Key.DisplayString(), this.User, this.tlsConfig != nil) } func (this *ConnectionConfig) Equals(other *ConnectionConfig) bool { return this.Key.Equals(&other.Key) || this.ImpliedKey.Equals(other.ImpliedKey) } +func (this *ConnectionConfig) UseTLS(caCertificatePath string) error { + skipVerify := caCertificatePath == "" + var rootCertPool *x509.CertPool + if !skipVerify { + rootCertPool = x509.NewCertPool() + pem, err := ioutil.ReadFile(caCertificatePath) + if err != nil { + return err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return errors.New("could not add ca certificate to cert pool") + } + } + + this.tlsConfig = &tls.Config{ + RootCAs: rootCertPool, + InsecureSkipVerify: skipVerify, + } + + if err := mysql.RegisterTLSConfig(this.Key.StringCode(), this.tlsConfig); err != nil { + return err + } + return nil +} + +func (this *ConnectionConfig) TLSConfig() *tls.Config { + return this.tlsConfig +} + func (this *ConnectionConfig) GetDBUri(databaseName string) string { hostname := this.Key.Hostname var ip = net.ParseIP(hostname) @@ -57,5 +94,9 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { hostname = fmt.Sprintf("[%s]", hostname) } interpolateParams := true - return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?interpolateParams=%t&autocommit=true&charset=utf8mb4,utf8,latin1", this.User, this.Password, hostname, this.Key.Port, databaseName, interpolateParams) + tlsOption := "false" + if this.tlsConfig != nil { + tlsOption = this.Key.StringCode() + } + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?interpolateParams=%t&autocommit=true&charset=utf8mb4,utf8,latin1&tls=%s", this.User, this.Password, hostname, this.Key.Port, databaseName, interpolateParams, tlsOption) } diff --git a/go/mysql/connection_test.go b/go/mysql/connection_test.go index 19badf3..3a6cf29 100644 --- a/go/mysql/connection_test.go +++ b/go/mysql/connection_test.go @@ -6,6 +6,7 @@ package mysql import ( + "crypto/tls" "testing" "github.com/outbrain/golib/log" @@ -31,6 +32,10 @@ func TestDuplicateCredentials(t *testing.T) { c.Key = InstanceKey{Hostname: "myhost", Port: 3306} c.User = "gromit" c.Password = "penguin" + c.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + ServerName: "feathers", + } dup := c.DuplicateCredentials(InstanceKey{Hostname: "otherhost", Port: 3310}) test.S(t).ExpectEquals(dup.Key.Hostname, "otherhost") @@ -39,6 +44,7 @@ func TestDuplicateCredentials(t *testing.T) { test.S(t).ExpectEquals(dup.ImpliedKey.Port, 3310) test.S(t).ExpectEquals(dup.User, "gromit") test.S(t).ExpectEquals(dup.Password, "penguin") + test.S(t).ExpectEquals(dup.tlsConfig, c.tlsConfig) } func TestDuplicate(t *testing.T) { @@ -63,5 +69,16 @@ func TestGetDBUri(t *testing.T) { c.Password = "penguin" uri := c.GetDBUri("test") - test.S(t).ExpectEquals(uri, "gromit:penguin@tcp(myhost:3306)/test?interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1") + test.S(t).ExpectEquals(uri, "gromit:penguin@tcp(myhost:3306)/test?interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1&tls=false") +} + +func TestGetDBUriWithTLSSetup(t *testing.T) { + c := NewConnectionConfig() + c.Key = InstanceKey{Hostname: "myhost", Port: 3306} + c.User = "gromit" + c.Password = "penguin" + c.tlsConfig = &tls.Config{} + + uri := c.GetDBUri("test") + test.S(t).ExpectEquals(uri, "gromit:penguin@tcp(myhost:3306)/test?interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1&tls=myhost:3306") }