134 lines
3.3 KiB
Go
134 lines
3.3 KiB
Go
|
package server
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/rand"
|
||
|
"crypto/rsa"
|
||
|
"crypto/sha1"
|
||
|
"crypto/sha256"
|
||
|
"crypto/tls"
|
||
|
"fmt"
|
||
|
|
||
|
"github.com/juju/errors"
|
||
|
. "github.com/siddontang/go-mysql/mysql"
|
||
|
)
|
||
|
|
||
|
func (c *Conn) handleAuthSwitchResponse() error {
|
||
|
authData, err := c.readAuthSwitchRequestResponse()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch c.authPluginName {
|
||
|
case AUTH_NATIVE_PASSWORD:
|
||
|
if err := c.acquirePassword(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) {
|
||
|
return ErrAccessDenied
|
||
|
}
|
||
|
return nil
|
||
|
|
||
|
case AUTH_CACHING_SHA2_PASSWORD:
|
||
|
if !c.cachingSha2FullAuth {
|
||
|
// Switched auth method but no MoreData packet send yet
|
||
|
if err := c.compareCacheSha2PasswordAuthData(authData); err != nil {
|
||
|
return err
|
||
|
} else {
|
||
|
if c.cachingSha2FullAuth {
|
||
|
return c.handleAuthSwitchResponse()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
// AuthMoreData packet already sent, do full auth
|
||
|
if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
c.writeCachingSha2Cache()
|
||
|
return nil
|
||
|
|
||
|
case AUTH_SHA256_PASSWORD:
|
||
|
cont, err := c.handlePublicKeyRetrieval(authData)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !cont {
|
||
|
return nil
|
||
|
}
|
||
|
if err := c.acquirePassword(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return c.compareSha256PasswordAuthData(authData, c.password)
|
||
|
|
||
|
default:
|
||
|
return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
|
||
|
if err := c.acquirePassword(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok {
|
||
|
if !tlsConn.ConnectionState().HandshakeComplete {
|
||
|
return errors.New("incomplete TSL handshake")
|
||
|
}
|
||
|
// connection is SSL/TLS, client should send plain password
|
||
|
// deal with the trailing \NUL added for plain text password received
|
||
|
if l := len(authData); l != 0 && authData[l-1] == 0x00 {
|
||
|
authData = authData[:l-1]
|
||
|
}
|
||
|
if bytes.Equal(authData, []byte(c.password)) {
|
||
|
return nil
|
||
|
}
|
||
|
return ErrAccessDenied
|
||
|
} else {
|
||
|
// client either request for the public key or send the encrypted password
|
||
|
if len(authData) == 1 && authData[0] == 0x02 {
|
||
|
// send the public key
|
||
|
if err := c.writeAuthMoreDataPubkey(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// read the encrypted password
|
||
|
var err error
|
||
|
if authData, err = c.readAuthSwitchRequestResponse(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
// the encrypted password
|
||
|
// decrypt
|
||
|
dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), authData, nil)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
plain := make([]byte, len(c.password)+1)
|
||
|
copy(plain, c.password)
|
||
|
for i := range plain {
|
||
|
j := i % len(c.salt)
|
||
|
plain[i] ^= c.salt[j]
|
||
|
}
|
||
|
if bytes.Equal(plain, dbytes) {
|
||
|
return nil
|
||
|
}
|
||
|
return ErrAccessDenied
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *Conn) writeCachingSha2Cache() {
|
||
|
// write cache
|
||
|
if c.password == "" {
|
||
|
return
|
||
|
}
|
||
|
// SHA256(PASSWORD)
|
||
|
crypt := sha256.New()
|
||
|
crypt.Write([]byte(c.password))
|
||
|
m1 := crypt.Sum(nil)
|
||
|
// SHA256(SHA256(PASSWORD))
|
||
|
crypt.Reset()
|
||
|
crypt.Write(m1)
|
||
|
m2 := crypt.Sum(nil)
|
||
|
// caching_sha2_password will maintain an in-memory hash of `user`@`host` => SHA256(SHA256(PASSWORD))
|
||
|
c.serverConf.cacheShaPassword.Store(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr()), m2)
|
||
|
}
|