syncthing/vendor/github.com/lib/pq/conn.go

1848 lines
43 KiB
Go
Raw Normal View History

2016-05-31 20:35:35 +00:00
package pq
import (
"bufio"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/user"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"unicode"
"github.com/lib/pq/oid"
)
// Common error types returned or raised by this driver.
var (
	// ErrNotSupported reports an unsupported command.
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the server reports
	// the transaction as failed; Commit rolls it back first.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is raised during connection setup when the server
	// declines the SSL request.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyHasWorldPermissions is raised when the client private key
	// file has permissions that are too open.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
	// ErrCouldNotDetectUsername presumably reports that no default user name
	// could be determined (userCurrent is not shown here) — verify.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
)
// drv is the database/sql driver implementation for PostgreSQL.
type drv struct{}

// Open implements driver.Driver by delegating to the package-level Open.
func (d *drv) Open(name string) (driver.Conn, error) {
	cn, err := Open(name)
	return cn, err
}

// init registers the driver with database/sql under the name "postgres".
func init() {
	d := &drv{}
	sql.Register("postgres", d)
}
// parameterStatus holds session state reported by the server. It is
// presumably populated by processParameterStatus (not shown here) from
// ParameterStatus messages, and is passed to encode/decode.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
// transactionStatus mirrors the server's transaction status indicator:
// 'I' (idle), 'T' (in a transaction block) or 'E' (in a failed transaction).
type transactionStatus byte

const (
	txnStatusIdle                transactionStatus = 'I'
	txnStatusIdleInTransaction   transactionStatus = 'T'
	txnStatusInFailedTransaction transactionStatus = 'E'
)

// String returns a human-readable description of the status. An unknown
// status is reported through errorf, which does not return.
func (s transactionStatus) String() string {
	if s == txnStatusIdle {
		return "idle"
	}
	if s == txnStatusIdleInTransaction {
		return "idle in transaction"
	}
	if s == txnStatusInFailedTransaction {
		return "in a failed transaction"
	}
	errorf("unknown transactionStatus %d", s)
	panic("not reached")
}
// Dialer opens the network connection to the server. DialOpen accepts a
// custom implementation; Open uses defaultDialer.
type Dialer interface {
	// Dial connects to address on the named network.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is like Dial but bounds the connection attempt by timeout.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// defaultDialer implements Dialer with the net package's dial functions.
type defaultDialer struct{}

// Dial forwards to net.Dial.
func (defaultDialer) Dial(nw, address string) (net.Conn, error) {
	c, err := net.Dial(nw, address)
	return c, err
}

// DialTimeout forwards to net.DialTimeout.
func (defaultDialer) DialTimeout(nw, address string, timeout time.Duration) (net.Conn, error) {
	c, err := net.DialTimeout(nw, address, timeout)
	return c, err
}
// conn is a single connection to a PostgreSQL server. It implements
// driver.Conn (Prepare, Close, Begin) and the optional Query/Exec methods,
// and is also the driver.Tx returned by Begin.
type conn struct {
	// c is the network connection; ssl replaces it with a TLS client
	// connection when SSL is negotiated.
	c net.Conn

	// buf buffers reads from c.
	buf *bufio.Reader

	// namei is the counter gname uses to generate statement names.
	namei int

	// scratch is reused to build outgoing messages (writeBuf) and to hold
	// incoming messages that fit in it (recvMessage).
	scratch [512]byte

	// txnStatus is the transaction status last reported by the server
	// (presumably updated by processReadyForQuery, not shown here).
	txnStatus transactionStatus

	// parameterStatus holds server-reported session settings that are
	// passed to encode and decode.
	parameterStatus parameterStatus

	// saveMessageType and saveMessageBuffer hold one message stashed by
	// saveMessage; recvMessage returns it on its next call. A zero type means
	// nothing is saved.
	saveMessageType   byte
	saveMessageBuffer []byte

	// If true, this connection is bad and all public-facing functions should
	// return ErrBadConn.
	bad bool

	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool

	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool
}
// handleDriverSettings copies the driver-only boolean options
// "disable_prepared_binary_result" and "binary_parameters" from o into c.
// Each accepts exactly "yes" or "no". An absent or empty option leaves the
// field unchanged. Any other value yields an error, and the remaining
// options are not processed.
func (c *conn) handleDriverSettings(o values) (err error) {
	settings := []struct {
		key string
		dst *bool
	}{
		{"disable_prepared_binary_result", &c.disablePreparedBinaryResult},
		{"binary_parameters", &c.binaryParameters},
	}
	for _, s := range settings {
		switch v := o.Get(s.key); v {
		case "":
			// not provided: keep the current value
		case "yes":
			*s.dst = true
		case "no":
			*s.dst = false
		default:
			return fmt.Errorf("unrecognized value %q for %s", v, s.key)
		}
	}
	return nil
}
// handlePgpass fills in o["password"] from a libpq-style password file when
// no password was supplied explicitly. The file is $PGPASSFILE if set,
// otherwise ~/.pgpass.
//
// As in libpq, the file is ignored unless it is a regular file without any
// group or world permission bits. Each non-empty, non-comment line has the
// form host:port:database:username:password. A "*" in one of the first four
// fields matches anything. "localhost" also matches an empty host or a
// UNIX-socket connection. A backslash makes the next character literal.
// The first matching line wins. This is best effort: any failure leaves o
// unchanged.
func (c *conn) handlePgpass(o values) {
	// if a password was supplied, do not process .pgpass
	if _, ok := o["password"]; ok {
		return
	}
	filename := os.Getenv("PGPASSFILE")
	if filename == "" {
		// XXX this code doesn't work on Windows where the default filename is
		// XXX %APPDATA%\postgresql\pgpass.conf
		u, err := user.Current()
		if err != nil {
			return
		}
		filename = filepath.Join(u.HomeDir, ".pgpass")
	}
	fileinfo, err := os.Stat(filename)
	if err != nil {
		return
	}
	// The mask is octal 0077: the group and other rwx bits. (A hex 0x77
	// mask would test owner-execute and miss group-execute.)
	mode := fileinfo.Mode()
	if !mode.IsRegular() || mode.Perm()&0077 != 0 {
		// XXX should warn about incorrect .pgpass permissions as psql does
		return
	}
	file, err := os.Open(filename)
	if err != nil {
		return
	}
	defer file.Close()
	lines := bufio.NewScanner(io.Reader(file))
	hostname := o.Get("host")
	ntw, _ := network(o)
	port := o.Get("port")
	db := o.Get("dbname")
	username := o.Get("user")
	// From: https://github.com/tg/pgpass/blob/master/reader.go
	getFields := func(s string) []string {
		fs := make([]string, 0, 5)
		f := make([]rune, 0, len(s))
		var esc bool
		for _, ch := range s {
			switch {
			case esc:
				f = append(f, ch)
				esc = false
			case ch == '\\':
				esc = true
			case ch == ':':
				fs = append(fs, string(f))
				f = f[:0]
			default:
				f = append(f, ch)
			}
		}
		return append(fs, string(f))
	}
	for lines.Scan() {
		line := lines.Text()
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		split := getFields(line)
		if len(split) != 5 {
			continue
		}
		hostOK := split[0] == "*" || split[0] == hostname ||
			(split[0] == "localhost" && (hostname == "" || ntw == "unix"))
		if hostOK &&
			(split[1] == "*" || split[1] == port) &&
			(split[2] == "*" || split[2] == db) &&
			(split[3] == "*" || split[3] == username) {
			o["password"] = split[4]
			return
		}
	}
}
// writeBuf starts a new outgoing message of type b in the connection's
// scratch buffer. The buffer starts as the first 5 scratch bytes with the
// type byte at index 0 and pos set to 1. Presumably bytes 1-4 are reserved
// for the length that wrap fills in; writeBuf's methods are not shown here.
func (c *conn) writeBuf(b byte) *writeBuf {
	head := c.scratch[:5]
	head[0] = b
	return &writeBuf{buf: head, pos: 1}
}
func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
}
// DialOpen opens a new connection to the database, using d to establish the
// network connection. See Open for the accepted formats of name.
//
// Settings are applied in increasing order of precedence: built-in
// defaults, environment variables (as interpreted by parseEnviron), then
// the values in name. client_encoding and datestyle may only be set to the
// values the driver itself requires. If anything fails after the network
// connection has been established, that connection is closed before the
// error is returned. Such failures include TLS negotiation, startup,
// authentication, and resetting the connect deadline.
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
	var cn *conn
	// Deferred calls run last-in first-out, so this one is registered before
	// errRecoverNoErrBadConn: it must see err after a panic was converted.
	defer func() {
		if err != nil && cn != nil && cn.c != nil {
			// The setup error is more useful than any error from closing
			// the half-open connection.
			_ = cn.c.Close()
		}
	}()
	// Handle any panics during connection initialization. Note that we
	// specifically do *not* want to use errRecover(), as that would turn any
	// connection errors into ErrBadConns, hiding the real error message from
	// the user.
	defer errRecoverNoErrBadConn(&err)

	o := make(values)

	// A number of defaults are applied here, in this order:
	//
	// * Very low precedence defaults applied in every situation
	// * Environment variables
	// * Explicitly passed connection information
	o.Set("host", "localhost")
	o.Set("port", "5432")
	// N.B.: Extra float digits should be set to 3, but that breaks
	// Postgres 8.4 and older, where the max is 2.
	o.Set("extra_float_digits", "2")
	for k, v := range parseEnviron(os.Environ()) {
		o.Set(k, v)
	}

	if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
		name, err = ParseURL(name)
		if err != nil {
			return nil, err
		}
	}

	if err := parseOpts(name, o); err != nil {
		return nil, err
	}

	// Use the "fallback" application name if necessary
	if fallback := o.Get("fallback_application_name"); fallback != "" {
		if !o.Isset("application_name") {
			o.Set("application_name", fallback)
		}
	}

	// We can't work with any client_encoding other than UTF-8 currently.
	// However, we have historically allowed the user to set it to UTF-8
	// explicitly, and there's no reason to break such programs, so allow that.
	// Note that the "options" setting could also set client_encoding, but
	// parsing its value is not worth it. Instead, we always explicitly send
	// client_encoding as a separate run-time parameter, which should override
	// anything set in options.
	if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) {
		return nil, errors.New("client_encoding must be absent or 'UTF8'")
	}
	o.Set("client_encoding", "UTF8")

	// DateStyle needs a similar treatment. Report a wrong value as a plain
	// error rather than panicking with a string.
	if datestyle := o.Get("datestyle"); datestyle != "" {
		if datestyle != "ISO, MDY" {
			return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v",
				"ISO, MDY", datestyle)
		}
	} else {
		o.Set("datestyle", "ISO, MDY")
	}

	// If a user is not provided by any other means, the last
	// resort is to use the current operating system provided user
	// name.
	if o.Get("user") == "" {
		u, err := userCurrent()
		if err != nil {
			return nil, err
		}
		o.Set("user", u)
	}

	cn = &conn{}
	err = cn.handleDriverSettings(o)
	if err != nil {
		return nil, err
	}
	cn.handlePgpass(o)

	cn.c, err = dial(d, o)
	if err != nil {
		return nil, err
	}
	cn.ssl(o)
	cn.buf = bufio.NewReader(cn.c)
	cn.startup(o)

	// reset the deadline, in case one was set (see dial)
	if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
		// On failure, return no connection: callers discard it anyway, and
		// the deferred close above releases the socket.
		if err = cn.c.SetDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}
	return cn, nil
}
// dial opens the network connection described by o using d.
//
// UNIX-domain sockets force sslmode=disable, since SSL is neither needed nor
// supported over them. A positive connect_timeout (in seconds) bounds the
// dial itself. It also sets a deadline on the returned connection covering
// the rest of the handshake; DialOpen clears that deadline after startup.
// As in libpq, zero, a negative value, or an absent setting means wait
// indefinitely. A non-integer value is an error.
func dial(d Dialer, o values) (net.Conn, error) {
	ntw, addr := network(o)
	// SSL is not necessary or supported over UNIX domain sockets
	if ntw == "unix" {
		o["sslmode"] = "disable"
	}

	timeout := o.Get("connect_timeout")
	if timeout == "" || timeout == "0" {
		return d.Dial(ntw, addr)
	}
	seconds, err := strconv.ParseInt(timeout, 10, 0)
	if err != nil {
		return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
	}
	if seconds <= 0 {
		// e.g. "00" or "-1": wait indefinitely rather than applying a deadline
		// that has already passed.
		return d.Dial(ntw, addr)
	}
	duration := time.Duration(seconds) * time.Second

	// connect_timeout should apply to the entire connection establishment
	// procedure, so we both use a timeout for the TCP connection
	// establishment and set a deadline for doing the initial handshake.
	// The deadline is then reset after startup() is done.
	deadline := time.Now().Add(duration)
	c, err := d.DialTimeout(ntw, addr, duration)
	if err != nil {
		return nil, err
	}
	if err := c.SetDeadline(deadline); err != nil {
		// Callers discard the connection on error, so close it here.
		c.Close()
		return nil, err
	}
	return c, nil
}
// network returns the network type and address to dial for o. A host that
// begins with "/" is a directory holding the UNIX-domain socket
// ".s.PGSQL.<port>". Any other host is a TCP host joined with the port.
func network(o values) (string, string) {
	host, port := o.Get("host"), o.Get("port")
	if !strings.HasPrefix(host, "/") {
		return "tcp", net.JoinHostPort(host, port)
	}
	return "unix", path.Join(host, ".s.PGSQL."+port)
}
// values holds connection settings keyed by parameter name.
type values map[string]string

// Set stores v under key k, replacing any previous value.
func (vs values) Set(k, v string) { vs[k] = v }

// Get returns the value stored under k, or "" if k is absent.
func (vs values) Get(k string) (v string) {
	v = vs[k]
	return v
}

// Isset reports whether k is present, even if its value is "".
func (vs values) Isset(k string) bool {
	if _, present := vs[k]; present {
		return true
	}
	return false
}
// scanner tokenizes libpq-style option strings one rune at a time.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a scanner positioned at the start of s.
func newScanner(s string) *scanner {
	return &scanner{s: []rune(s), i: 0}
}

// Next returns the next rune and advances past it, or 0, false once the
// input is exhausted.
func (s *scanner) Next() (rune, bool) {
	if s.i < len(s.s) {
		r := s.s[s.i]
		s.i++
		return r, true
	}
	return 0, false
}

// SkipSpaces consumes whitespace and returns the first non-whitespace rune,
// or 0, false if the input ends first.
func (s *scanner) SkipSpaces() (rune, bool) {
	for {
		r, ok := s.Next()
		if !ok || !unicode.IsSpace(r) {
			return r, ok
		}
	}
}
// parseOpts parses a libpq-style connection string such as
// "key=value key2='quoted value'" and stores each pair in o, overriding any
// earlier value.
//
// Whitespace around "=" is ignored. An unquoted value ends at the next
// whitespace character, while a single-quoted value may contain whitespace.
// In both forms a backslash makes the next character literal. A key with
// nothing after its "=" at the end of the string gets the empty value, as
// in libpq. Returns an error for a missing key or "=", a trailing
// backslash in an unquoted value, or an unterminated quoted value.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// An empty key would be sent to the server as the empty string,
		// which terminates the startup packet's parameter list early.
		if len(keyRunes) == 0 {
			return fmt.Errorf(`missing parameter name before "=" in connection info string`)
		}

		// Skip any whitespace if we're not at the = yet
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =
		if r != '=' || !ok {
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =
		if r, ok = s.SkipSpaces(); !ok {
			// If we reach the end here, the last value is just an empty string as per libpq.
			o.Set(string(keyRunes), "")
			break
		}

		if r != '\'' {
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`missing character after backslash`)
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return fmt.Errorf(`unterminated quoted string literal in connection string`)
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					r, _ = s.Next()
					fallthrough
				default:
					valRunes = append(valRunes, r)
				}
			}
		}

		o.Set(string(keyRunes), string(valRunes))
	}

	return nil
}
// isInTransaction reports whether the last known server status is inside a
// transaction block, whether or not that transaction has failed.
func (cn *conn) isInTransaction() bool {
	switch cn.txnStatus {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	}
	return false
}

// checkIsInTransaction marks the connection bad and reports an error via
// errorf (which does not return) when the current transaction state does
// not match intxn.
func (cn *conn) checkIsInTransaction(intxn bool) {
	if cn.isInTransaction() == intxn {
		return
	}
	cn.bad = true
	errorf("unexpected transaction status %v", cn.txnStatus)
}
// Begin starts a transaction by running BEGIN through the simple query
// protocol and returns the connection itself as the driver.Tx. The
// connection is marked bad if the server does not answer with a "BEGIN"
// command tag and an idle-in-transaction status.
func (cn *conn) Begin() (_ driver.Tx, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(false)
	_, tag, err := cn.simpleExec("BEGIN")
	switch {
	case err != nil:
		return nil, err
	case tag != "BEGIN":
		cn.bad = true
		return nil, fmt.Errorf("unexpected command tag %s", tag)
	case cn.txnStatus != txnStatusIdleInTransaction:
		cn.bad = true
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
	}
	return cn, nil
}
// Commit ends the current transaction with COMMIT. A transaction the server
// reports as failed is rolled back instead, and ErrInFailedTransaction is
// returned. A failed COMMIT that leaves the server inside a transaction, or
// an unexpected command tag, marks the connection bad.
func (cn *conn) Commit() (err error) {
	if cn.bad {
		return driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(true)
	// We don't want the client to think that everything is okay if it tries
	// to commit a failed transaction. However, no matter what we return,
	// database/sql will release this connection back into the free connection
	// pool so we have to abort the current transaction here. Note that you
	// would get the same behaviour if you issued a COMMIT in a failed
	// transaction, so it's also the least surprising thing to do here.
	if cn.txnStatus == txnStatusInFailedTransaction {
		if rbErr := cn.Rollback(); rbErr != nil {
			return rbErr
		}
		return ErrInFailedTransaction
	}

	_, tag, err := cn.simpleExec("COMMIT")
	if err != nil {
		if cn.isInTransaction() {
			cn.bad = true
		}
		return err
	}
	if tag != "COMMIT" {
		cn.bad = true
		return fmt.Errorf("unexpected command tag %s", tag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
// Rollback aborts the current transaction with ROLLBACK.
//
// The connection is marked bad if the ROLLBACK fails while the server still
// reports an open transaction, or if the server answers with an unexpected
// command tag. In either case its transaction state can no longer be
// trusted, and Begin and Commit already treat an unexpected tag this way.
func (cn *conn) Rollback() (err error) {
	if cn.bad {
		return driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(true)
	_, commandTag, err := cn.simpleExec("ROLLBACK")
	if err != nil {
		if cn.isInTransaction() {
			cn.bad = true
		}
		return err
	}
	if commandTag != "ROLLBACK" {
		cn.bad = true
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
// gname returns a new prepared-statement name for this connection: the
// decimal form of an incrementing counter ("1", "2", ...).
func (cn *conn) gname() string {
	cn.namei++
	name := strconv.Itoa(cn.namei)
	return name
}
// simpleExec runs q through the simple query protocol and discards any
// result rows. Once the server reports ReadyForQuery, it returns the result
// and command tag from the last CommandComplete message, plus the error from
// the last ErrorResponse, if any. Unknown messages mark the connection bad.
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
	msg := cn.writeBuf('Q')
	msg.string(q)
	cn.send(msg)

	for {
		typ, r := cn.recv1()
		switch typ {
		case 'Z':
			cn.processReadyForQuery(r)
			return res, commandTag, err
		case 'C':
			res, commandTag = cn.parseComplete(r.string())
		case 'E':
			err = parseError(r)
		case 'T', 'D', 'I':
			// row descriptions, data rows and empty-query responses carry
			// nothing an Exec caller needs
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", typ)
		}
	}
}
// simpleQuery runs q through the simple query protocol and returns a rows
// object for reading its results.
//
// Messages are consumed until the first DataRow ('D') or ReadyForQuery
// ('Z'). A DataRow is stashed with saveMessage so that rows.Next reads it
// next. Statements that return no rows still yield a *rows with done set,
// so database/sql has something to close. An ErrorResponse ('E') discards
// any rows object built so far; its error is returned once ReadyForQuery
// arrives. Panics raised while reading are converted into err by
// errRecover.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			// Receiving this after an ErrorResponse means we are out of sync.
			if err != nil {
				cn.bad = true
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			res = nil
			err = parseError(r)
		case 'D':
			// A DataRow must follow a row description ('T').
			if res == nil {
				cn.bad = true
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
// decideColumnFormats chooses the result format for each column of a
// prepared statement, given one type oid per result column. It returns the
// per-column formats plus the encoded result-format section of a Bind
// message. bytea, int8, int4 and int2 columns are received as binary (they
// must also be supported by binaryDecode in encode.go). All other columns,
// and every column when forceText is set, use text. Uniform choices reuse
// the shared colFmtDataAllBinary / colFmtDataAllText encodings.
func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) {
	if len(colTyps) == 0 {
		return nil, colFmtDataAllText
	}

	colFmts = make([]format, len(colTyps))
	if forceText {
		return colFmts, colFmtDataAllText
	}

	binaryCount := 0
	for i, typ := range colTyps {
		switch typ {
		case oid.T_bytea, oid.T_int8, oid.T_int4, oid.T_int2:
			colFmts[i] = formatBinary
			binaryCount++
		}
	}

	switch binaryCount {
	case len(colFmts):
		return colFmts, colFmtDataAllBinary
	case 0:
		return colFmts, colFmtDataAllText
	}

	// Mixed formats: a count followed by one 16-bit code per column.
	colFmtData = make([]byte, 2+2*len(colFmts))
	binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
	for i, f := range colFmts {
		binary.BigEndian.PutUint16(colFmtData[2+2*i:], uint16(f))
	}
	return colFmts, colFmtData
}
// prepareTo parses q on the server as the prepared statement stmtName
// ("" selects the unnamed statement) and returns its description.
//
// It writes Parse ('P', with zero pre-specified parameter types), Describe
// of the statement ('D' 'S') and Sync ('S') in a single send. It then
// reads the parse response, the parameter and row descriptions, and
// ReadyForQuery. Result column formats are chosen by decideColumnFormats.
// Errors panic; presumably callers recover them with errRecover.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}
	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	b.int16(0)
	b.next('D')
	b.byte('S')
	b.string(st.name)
	b.next('S')
	cn.send(b)
	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
// Prepare implements driver.Conn. Queries starting with COPY (in any
// letter case) go to prepareCopyIn. Anything else becomes a server-side
// prepared statement with a freshly generated name.
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	isCopy := len(q) >= 4 && strings.EqualFold(q[:4], "COPY")
	if isCopy {
		return cn.prepareCopyIn(q)
	}
	st := cn.prepareTo(q, cn.gname())
	return st, nil
}
// Close sends a Terminate message and closes the underlying network
// connection. The socket is closed even when the connection has been
// marked bad or the Terminate message cannot be written. database/sql
// closes connections that returned driver.ErrBadConn, and bailing out early
// would leak the socket. The write error is returned if there was one;
// otherwise the error from closing the socket.
func (cn *conn) Close() (err error) {
	// No early return on cn.bad: the network connection must always be
	// closed.
	defer cn.errRecover(&err)

	// Error handling in this package is done with panics recovered by
	// errRecover, so the close is deferred to guarantee it runs.
	defer func() {
		cerr := cn.c.Close()
		if err == nil {
			err = cerr
		}
	}()

	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	return cn.sendSimpleMessage('X')
}
// Query implements the optional driver.Queryer interface for one-shot
// queries. Queries without arguments use the simple query protocol, which
// is *much* faster than going through prepare/exec. With binary_parameters,
// the statement is sent as a single round trip (sendBinaryModeQuery).
// Otherwise it is prepared as the unnamed statement and executed.
func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	if len(args) == 0 {
		return cn.simpleQuery(query)
	}

	if !cn.binaryParameters {
		st := cn.prepareTo(query, "")
		st.exec(args)
		return &rows{
			cn:       cn,
			colNames: st.colNames,
			colTyps:  st.colTyps,
			colFmts:  st.colFmts,
		}, nil
	}

	cn.sendBinaryModeQuery(query, args)
	cn.readParseResponse()
	cn.readBindResponse()
	rs := &rows{cn: cn}
	rs.colNames, rs.colFmts, rs.colTyps = cn.readPortalDescribeResponse()
	cn.postExecuteWorkaround()
	return rs, nil
}
// Exec implements the optional driver.Execer interface for one-shot
// statements. Statements without arguments use the simple query protocol,
// which is *much* faster than going through prepare/exec. With
// binary_parameters, the statement is sent in a single round trip.
// Otherwise it is prepared as the unnamed statement and executed.
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	if len(args) == 0 {
		// ignore commandTag, our caller doesn't care
		result, _, execErr := cn.simpleExec(query)
		return result, execErr
	}

	if !cn.binaryParameters {
		// Use the unnamed statement to defer planning until bind
		// time, or else value-based selectivity estimates cannot be
		// used.
		st := cn.prepareTo(query, "")
		result, execErr := st.Exec(args)
		if execErr != nil {
			panic(execErr)
		}
		return result, nil
	}

	cn.sendBinaryModeQuery(query, args)
	cn.readParseResponse()
	cn.readBindResponse()
	cn.readPortalDescribeResponse()
	cn.postExecuteWorkaround()
	res, _, err = cn.readExecuteResponse("Execute")
	return res, err
}
// send writes the message built in m to the server and panics on any write
// error.
func (cn *conn) send(m *writeBuf) {
	if _, err := cn.c.Write(m.wrap()); err != nil {
		panic(err)
	}
}

// sendStartupPacket writes m without its leading type byte. Startup-style
// packets (the StartupMessage sent by startup and the SSL request sent by
// ssl) have no message type. m must have been created with writeBuf(0).
func (cn *conn) sendStartupPacket(m *writeBuf) {
	// sanity check
	if m.buf[0] != 0 {
		panic("oops")
	}
	packet := m.wrap()
	if _, err := cn.c.Write(packet[1:]); err != nil {
		panic(err)
	}
}

// sendSimpleMessage writes a payload-free message of type typ: the type byte
// followed by the big-endian length 4. It does not touch the scratch buffer
// and returns write errors instead of panicking.
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
	msg := []byte{typ, 0, 0, 0, 4}
	_, err = cn.c.Write(msg)
	return err
}
// saveMessage stashes a message so that the next recvMessage call returns
// it instead of reading from the server. This lets code peek at a message
// (e.g. to see whether it's an error) and leave its handling to someone
// else. Only one message can be stashed; saving a second one marks the
// connection bad and reports an error.
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
	if pending := cn.saveMessageType; pending != 0 {
		cn.bad = true
		errorf("unexpected saveMessageType %d", pending)
	}
	cn.saveMessageType, cn.saveMessageBuffer = typ, *buf
}
// recvMessage receives any message from the backend and returns its type,
// storing the payload in *r. A message stashed by saveMessage is returned
// first. Payloads that fit in the scratch buffer reuse it, so *r is only
// valid until the next read. An error is returned if reading fails. It is
// also returned if the server sends a length below the minimum of 4 (the
// length counts itself); in that case the stream cannot be resynchronized,
// so the connection is marked bad.
func (cn *conn) recvMessage(r *readBuf) (byte, error) {
	// workaround for a QueryRow bug, see exec
	if cn.saveMessageType != 0 {
		t := cn.saveMessageType
		*r = cn.saveMessageBuffer
		cn.saveMessageType = 0
		cn.saveMessageBuffer = nil
		return t, nil
	}

	x := cn.scratch[:5]
	_, err := io.ReadFull(cn.buf, x)
	if err != nil {
		return 0, err
	}

	// read the type and length of the message that follows
	t := x[0]
	size := binary.BigEndian.Uint32(x[1:])
	n := int(size - 4)
	// n < 0 is only possible on 32-bit platforms for huge lengths. Without
	// these checks the slicing below would raise a runtime panic, which
	// errRecover does not turn into an error.
	if size < 4 || n < 0 {
		cn.bad = true
		return 0, fmt.Errorf("pq: invalid message length %d for message type %q", size, t)
	}
	var y []byte
	if n <= len(cn.scratch) {
		y = cn.scratch[:n]
	} else {
		y = make([]byte, n)
	}
	_, err = io.ReadFull(cn.buf, y)
	if err != nil {
		return 0, err
	}
	*r = y
	return t, nil
}
// recv returns the next message from the backend, skipping NoticeResponses
// ('N'). It panics on a read error, and panics with the parsed error when
// the message is an ErrorResponse ('E'). Intended mainly for the startup
// sequence.
func (cn *conn) recv() (t byte, r *readBuf) {
	for {
		r = &readBuf{}
		var err error
		if t, err = cn.recvMessage(r); err != nil {
			panic(err)
		}
		if t == 'E' {
			panic(parseError(r))
		}
		if t != 'N' {
			return t, r
		}
	}
}
// recv1Buf is recv1 reading into a caller-supplied buffer, which avoids an
// allocation per message.
func (cn *conn) recv1Buf(r *readBuf) byte {
	for {
		t, err := cn.recvMessage(r)
		if err != nil {
			panic(err)
		}
		switch t {
		case 'A', 'N':
			continue
		case 'S':
			cn.processParameterStatus(r)
			continue
		}
		return t
	}
}
// recv1 returns the next message from the backend, panicking if it cannot
// be read. Messages of type 'A' and 'N' are skipped, and ParameterStatus
// ('S') messages are applied via processParameterStatus. Every other
// message, including ErrorResponse, is returned to the caller.
func (cn *conn) recv1() (t byte, r *readBuf) {
	buf := new(readBuf)
	return cn.recv1Buf(buf), buf
}
// ssl negotiates TLS on cn.c according to o["sslmode"]:
//   - "require" or unset: encrypt without verifying the server certificate;
//   - "verify-ca": check the certificate chain against sslrootcert (or the
//     system roots) but not the host name, via verifyCA;
//   - "verify-full": let crypto/tls verify the chain and the host name;
//   - "disable": do nothing.
//
// Any other mode is an error. On success cn.c is replaced by the TLS client
// connection. Failures are reported by panicking.
func (cn *conn) ssl(o values) {
	verifyCaOnly := false
	tlsConf := tls.Config{}
	switch mode := o.Get("sslmode"); mode {
	case "require", "":
		tlsConf.InsecureSkipVerify = true
	case "verify-ca":
		// We must skip TLS's own verification since it requires full
		// verification since Go 1.3.
		tlsConf.InsecureSkipVerify = true
		verifyCaOnly = true
	case "verify-full":
		tlsConf.ServerName = o.Get("host")
	case "disable":
		return
	default:
		errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
	}
	cn.setupSSLClientCertificates(&tlsConf, o)
	cn.setupSSLCA(&tlsConf, o)
	// SSL request: a startup-style packet carrying the code 80877103.
	w := cn.writeBuf(0)
	w.int32(80877103)
	cn.sendStartupPacket(w)
	// The server answers with a single byte; 'S' means it will accept TLS.
	b := cn.scratch[:1]
	_, err := io.ReadFull(cn.c, b)
	if err != nil {
		panic(err)
	}
	if b[0] != 'S' {
		panic(ErrSSLNotSupported)
	}
	client := tls.Client(cn.c, &tlsConf)
	if verifyCaOnly {
		// Handshake eagerly so the chain is checked before first use.
		cn.verifyCA(client, &tlsConf)
	}
	cn.c = client
}
// verifyCA carries out a TLS handshake to the server and verifies the
// presented certificate chain against the effective CA: the one from
// sslrootcert, or the system roots if tlsConf.RootCAs is nil. For
// sslmode=verify-ca, ssl leaves tlsConf.ServerName empty, so DNSName is
// empty and the host name is not checked. Panics on handshake or
// verification failure, or if the server presented no certificate.
func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) {
	if err := client.Handshake(); err != nil {
		panic(err)
	}
	state := client.ConnectionState()
	certs := state.PeerCertificates
	// Guard the certs[0] access below: an index panic would surface as an
	// opaque runtime error instead of a useful message.
	if len(certs) == 0 {
		panic(errors.New("pq: server presented no certificate"))
	}
	opts := x509.VerifyOptions{
		DNSName:       state.ServerName,
		Intermediates: x509.NewCertPool(),
		Roots:         tlsConf.RootCAs,
	}
	// Everything after the leaf is offered as an intermediate.
	for _, cert := range certs[1:] {
		opts.Intermediates.AddCert(cert)
	}
	if _, err := certs[0].Verify(opts); err != nil {
		panic(err)
	}
}
// setupSSLClientCertificates loads a client certificate into tlsConf. It
// uses the "sslkey" and "sslcert" settings (possibly from the PGSSLKEY and
// PGSSLCERT environment variables) when both are set. Otherwise it falls
// back to postgresql.key / postgresql.crt in ~/.postgresql. Explicitly
// configured files must exist, while missing default files are silently
// skipped.
//
// Like libpq, the key file must not grant any group or world permission,
// otherwise this function panics with ErrSSLKeyHasWorldPermissions. Owner
// permission bits are not restricted.
func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) {
	var missingOk bool

	sslkey := o.Get("sslkey")
	sslcert := o.Get("sslcert")
	if sslkey != "" && sslcert != "" {
		// If the user has set an sslkey and sslcert, they *must* exist.
		missingOk = false
	} else {
		// Automatically load certificates from ~/.postgresql.
		u, err := user.Current()
		if err != nil {
			// user.Current() might fail when cross-compiling. We have to
			// ignore the error and continue without client certificates, since
			// we wouldn't know where to load them from.
			return
		}

		sslkey = filepath.Join(u.HomeDir, ".postgresql", "postgresql.key")
		sslcert = filepath.Join(u.HomeDir, ".postgresql", "postgresql.crt")
		missingOk = true
	}

	// Check that both files exist, and report the error or stop, depending on
	// which behaviour we want. Note that we don't do any more extensive
	// checks than this (such as checking that the paths aren't directories);
	// LoadX509KeyPair() will take care of the rest.
	keyfinfo, err := os.Stat(sslkey)
	if err != nil && missingOk {
		return
	} else if err != nil {
		panic(err)
	}
	_, err = os.Stat(sslcert)
	if err != nil && missingOk {
		return
	} else if err != nil {
		panic(err)
	}

	// If we got this far, the key file must also have the correct
	// permissions. Only the group/other bits (octal 0077) are checked. This
	// matches the error message, which previously also fired for owner-execute
	// or for file-type bits.
	if keyfinfo.Mode().Perm()&0077 != 0 {
		panic(ErrSSLKeyHasWorldPermissions)
	}

	cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
	if err != nil {
		panic(err)
	}
	tlsConf.Certificates = []tls.Certificate{cert}
}
// setupSSLCA replaces tlsConf.RootCAs with a pool built from the PEM file
// named by the "sslrootcert" setting, if one is given. It panics if the
// file cannot be read and reports an error if it contains no parseable
// certificate.
func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) {
	sslrootcert := o.Get("sslrootcert")
	if sslrootcert == "" {
		return
	}
	tlsConf.RootCAs = x509.NewCertPool()
	pemData, err := ioutil.ReadFile(sslrootcert)
	if err != nil {
		panic(err)
	}
	if !tlsConf.RootCAs.AppendCertsFromPEM(pemData) {
		errorf("couldn't parse pem in sslrootcert")
	}
}
// isDriverSetting reports whether key configures only the driver itself
// (connection target, credentials, TLS, timeouts, driver behaviour). Such
// keys are not sent to the server as run-time parameters in the startup
// packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port", "password",
		"sslmode", "sslcert", "sslkey", "sslrootcert",
		"fallback_application_name", "connect_timeout",
		"disable_prepared_binary_result", "binary_parameters":
		return true
	}
	return false
}
// startup sends the startup packet and runs the startup exchange until the
// server reports ReadyForQuery. The packet starts with the protocol version
// 196608 (3<<16, i.e. 3.0).
//
// Every setting in o except the driver-only ones (see isDriverSetting) is
// sent as a run-time parameter, in map iteration order. "dbname" is renamed
// to "database" as the protocol requires. Server errors panic via recv.
func (cn *conn) startup(o values) {
	w := cn.writeBuf(0)
	w.int32(196608)
	// Send the backend the name of the database we want to connect to, and the
	// user we want to connect as. Additionally, we send over any run-time
	// parameters potentially included in the connection string. If the server
	// doesn't recognize any of them, it will reply with an error.
	for k, v := range o {
		if isDriverSetting(k) {
			// skip options which can't be run-time parameters
			continue
		}
		// The protocol requires us to supply the database name as "database"
		// instead of "dbname".
		if k == "dbname" {
			k = "database"
		}
		w.string(k)
		w.string(v)
	}
	// An empty name terminates the parameter list.
	w.string("")
	cn.sendStartupPacket(w)
	for {
		t, r := cn.recv()
		switch t {
		case 'K':
			// BackendKeyData: ignored here
		case 'S':
			// ParameterStatus
			cn.processParameterStatus(r)
		case 'R':
			// authentication request
			cn.auth(r, o)
		case 'Z':
			// ReadyForQuery: startup is complete
			cn.processReadyForQuery(r)
			return
		default:
			errorf("unknown response for startup: %q", t)
		}
	}
}
// auth answers one authentication request r received during startup, using
// the "user" and "password" settings in o.
//
// Supported request codes: 0 (authentication OK), 3 (cleartext password)
// and 5 (MD5 password with a 4-byte salt). Any other code is an error.
// After a password is sent, the server must reply with an 'R' message
// carrying code 0. Otherwise the unexpected message type, or the
// unexpected code, is reported.
func (cn *conn) auth(r *readBuf, o values) {
	// sendPassword sends a password message containing pw and requires the
	// server to accept it.
	sendPassword := func(pw string) {
		w := cn.writeBuf('p')
		w.string(pw)
		cn.send(w)

		t, resp := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}
		// Report the authentication code itself; the message type is
		// already known to be 'R' here.
		if code := resp.int32(); code != 0 {
			errorf("unexpected authentication response: %d", code)
		}
	}

	switch code := r.int32(); code {
	case 0:
		// OK
	case 3:
		sendPassword(o.Get("password"))
	case 5:
		// Copy the salt out of r before writeBuf reuses the scratch buffer.
		salt := string(r.next(4))
		sendPassword("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+salt))
	default:
		errorf("unknown authentication response: %d", code)
	}
}
// format is a wire-protocol format code for a parameter or result column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

var (
	// colFmtDataAllBinary is a result-format section declaring a single
	// format code, 1 (binary), that applies to every column.
	colFmtDataAllBinary = []byte{0, 1, 0, 1}
	// colFmtDataAllText declares no result format codes, i.e. all text.
	colFmtDataAllText = []byte{0, 0}
)
// stmt is a server-side prepared statement created by prepareTo.
type stmt struct {
	cn *conn

	// name is the server-side statement name ("" for the unnamed statement).
	name string

	// colNames describes the result columns, as reported when the statement
	// was described.
	colNames []string

	// colFmts is the chosen result format per column, and colFmtData the
	// matching encoded result-format section sent with Bind (see
	// decideColumnFormats).
	colFmts    []format
	colFmtData []byte

	// colTyps holds the result column type oids.
	colTyps []oid.Oid

	// paramTyps holds the parameter type oids reported by the server.
	paramTyps []oid.Oid

	// closed is set once Close has deallocated the statement.
	closed bool
}
// Close deallocates the prepared statement on the server. It sends a Close
// message for the statement ('C' 'S') and a Sync ('S'), then expects
// CloseComplete ('3') followed by ReadyForQuery ('Z'). Closing an already
// closed statement is a no-op. Unexpected replies mark the connection bad.
func (st *stmt) Close() (err error) {
	if st.closed {
		return nil
	}
	if st.cn.bad {
		return driver.ErrBadConn
	}
	defer st.cn.errRecover(&err)
	w := st.cn.writeBuf('C')
	w.byte('S')
	w.string(st.name)
	st.cn.send(w)
	st.cn.send(st.cn.writeBuf('S'))
	t, _ := st.cn.recv1()
	if t != '3' {
		st.cn.bad = true
		errorf("unexpected close response: %q", t)
	}
	st.closed = true
	t, r := st.cn.recv1()
	if t != 'Z' {
		st.cn.bad = true
		errorf("expected ready for query, but got: %q", t)
	}
	st.cn.processReadyForQuery(r)
	return nil
}
// Query implements driver.Stmt. It binds v, starts execution and returns a
// rows object that reads the results using the statement's column
// description.
func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
	cn := st.cn
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	st.exec(v)
	rs := &rows{cn: cn, colNames: st.colNames, colTyps: st.colTyps, colFmts: st.colFmts}
	return rs, nil
}
// Exec implements driver.Stmt. It binds v, executes the statement and
// returns the result from the execute response.
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
	cn := st.cn
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	st.exec(v)
	res, _, err = cn.readExecuteResponse("simple query")
	return res, err
}
// exec binds v to the statement in the unnamed portal and starts executing
// it. It writes Bind ('B'), Execute ('E', max rows 0 meaning no limit) and
// Sync ('S') in a single send, then reads the bind response. The results
// are left for the caller (rows.Next or readExecuteResponse) to consume.
// Argument-count errors are reported via errorf.
func (st *stmt) exec(v []driver.Value) {
	// Parameter counts are sent as 16-bit values.
	if len(v) >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
	}
	if len(v) != len(st.paramTyps) {
		errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
	}
	cn := st.cn
	w := cn.writeBuf('B')
	w.byte(0) // unnamed portal
	w.string(st.name)
	if cn.binaryParameters {
		cn.sendBinaryParameters(w, v)
	} else {
		// Zero parameter format codes: every parameter is sent as text.
		w.int16(0)
		w.int16(len(v))
		for i, x := range v {
			if x == nil {
				// A length of -1 denotes NULL.
				w.int32(-1)
			} else {
				b := encode(&cn.parameterStatus, x, st.paramTyps[i])
				w.int32(len(b))
				w.bytes(b)
			}
		}
	}
	w.bytes(st.colFmtData)
	w.next('E')
	w.byte(0)
	w.int32(0)
	w.next('S')
	cn.send(w)
	cn.readBindResponse()
	cn.postExecuteWorkaround()
}
// NumInput implements driver.Stmt. It returns the number of parameters
// reported for the statement when prepareTo described it.
func (st *stmt) NumInput() int {
	n := len(st.paramTyps)
	return n
}
// parseComplete splits the command tag of a CommandComplete message into
// the number of affected rows (0 when the tag carries none) and the bare
// command name, e.g. "UPDATE 3" -> (3, "UPDATE"). For INSERT the middle oid
// field is ignored. A malformed count marks the connection bad and is
// reported via errorf.
func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
	commandsWithAffectedRows := []string{
		"SELECT ",
		// INSERT is handled below
		"UPDATE ",
		"DELETE ",
		"FETCH ",
		"MOVE ",
		"COPY ",
	}

	var rowCount string
	hasCount := false
	for _, prefix := range commandsWithAffectedRows {
		if strings.HasPrefix(commandTag, prefix) {
			rowCount = commandTag[len(prefix):]
			commandTag = strings.TrimSuffix(prefix, " ")
			hasCount = true
			break
		}
	}

	// INSERT also includes the oid of the inserted row in its command tag.
	// Oids in user tables are deprecated, and the oid is only returned when
	// exactly one row is inserted, so it's unlikely to be of value to any
	// real-world application and we can ignore it.
	if !hasCount && strings.HasPrefix(commandTag, "INSERT ") {
		fields := strings.Split(commandTag, " ")
		if len(fields) != 3 {
			cn.bad = true
			errorf("unexpected INSERT command tag %s", commandTag)
		}
		rowCount = fields[len(fields)-1]
		commandTag = "INSERT"
		hasCount = true
	}

	if !hasCount {
		// There should be no affected rows attached to the tag, just return it
		return driver.RowsAffected(0), commandTag
	}
	n, err := strconv.ParseInt(rowCount, 10, 64)
	if err != nil {
		cn.bad = true
		errorf("could not parse commandTag: %s", err)
	}
	return driver.RowsAffected(n), commandTag
}
// rows implements driver.Rows over the result messages of a query.
type rows struct {
	cn       *conn
	colNames []string
	colTyps  []oid.Oid
	colFmts  []format

	// done is set once no more rows will be read. Next sets it when it
	// receives ReadyForQuery; simpleQuery sets it for statements that
	// complete without returning rows. Next then returns io.EOF.
	done bool

	// rb is reused by Next as the receive buffer for every message.
	rb readBuf
}
// Close drains the result set by calling Next until it stops returning nil.
// Reaching io.EOF counts as a clean close; any other error is returned.
// Next itself checks whether the connection is bad, so Close does not.
func (rs *rows) Close() error {
	var err error
	for err == nil {
		err = rs.Next(nil)
	}
	if err == io.EOF {
		return nil
	}
	return err
}
// Columns returns the column names of the result set.
func (rs *rows) Columns() []string {
	names := rs.colNames
	return names
}
// Next decodes the next DataRow into dest, reading messages from the
// connection until a row or ReadyForQuery arrives.
//
// It returns io.EOF once ReadyForQuery has been processed (and on every later
// call, via rs.done), or driver.ErrBadConn if the connection is already
// marked bad. An ErrorResponse ('E') does not stop the loop: the error is kept
// in err and returned when ReadyForQuery arrives, and a DataRow seen after it
// marks the connection bad. The deferred conn.errRecover(&err) presumably
// turns panics raised by errorf into the returned error; confirm in errRecover.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}
	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)
	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Remember the error and keep reading until ReadyForQuery.
			err = parseError(&rs.rb)
		case 'C', 'I':
			// CommandComplete / EmptyQueryResponse carry no row data.
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			// n is the number of columns in this DataRow.
			n := rs.rb.int16()
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			// Decode at most n columns; any caller slots beyond n are left
			// untouched because only the local dest is truncated.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				// Each column is a 32-bit length then that many bytes;
				// a length of -1 denotes SQL NULL.
				l := rs.rb.int32()
				if l == -1 {
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i])
			}
			return
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
// QuoteIdentifier wraps an identifier (such as a table or column name) in
// double quotes so it can be embedded in an SQL statement. For example:
//
//    tblname := "my_table"
//    data := "my_data"
//    _, err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
//
// Embedded double quotes are escaped by doubling them. The quoted identifier
// is case sensitive when used in a query. If name contains a zero byte, the
// result is cut off just before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.Replace(name, `"`, `""`, -1)
	return `"` + escaped + `"`
}
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", sum[:])
}
// sendBinaryParameters writes the parameter portion of a Bind message to b:
// the parameter format codes, then each parameter value.
//
// Only []byte arguments are marked with format code 1 (binary). If none is
// []byte, a format-code count of 0 is written (all parameters use the
// default text format). Otherwise one code per argument is written. A nil
// argument is sent as length -1 (NULL); every other value is encoded with
// binaryEncode.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	// formats stays nil unless at least one argument is []byte, so its length
	// is either 0 or len(args).
	var formats []int
	for idx, arg := range args {
		if _, isBytes := arg.([]byte); !isBytes {
			continue
		}
		if formats == nil {
			formats = make([]int, len(args))
		}
		formats[idx] = 1
	}
	b.int16(len(formats))
	for _, code := range formats {
		b.int16(code)
	}

	b.int16(len(args))
	for _, arg := range args {
		if arg == nil {
			b.int32(-1)
			continue
		}
		datum := binaryEncode(&cn.parameterStatus, arg)
		b.int32(len(datum))
		b.bytes(datum)
	}
}
// sendBinaryModeQuery queues query and args as one extended-protocol batch
// on the unnamed statement and portal, then sends it:
//   - Parse
//   - Bind (parameters via sendBinaryParameters)
//   - Describe portal
//   - Execute
//   - Sync
//
// No responses are read here.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	// Parameter counts travel in 16-bit fields, so 65535 is the ceiling.
	if len(args) > 65535 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
	}

	w := cn.writeBuf('P')
	// Parse: empty statement name, the query text, no pre-declared types.
	w.byte(0)
	w.string(query)
	w.int16(0)

	// Bind: the two zero bytes are the empty portal name followed by the
	// empty statement name.
	w.next('B')
	w.int16(0)
	cn.sendBinaryParameters(w, args)
	w.bytes(colFmtDataAllText)

	// Describe the unnamed portal.
	w.next('D')
	w.byte('P')
	w.byte(0)

	// Execute the unnamed portal with a row limit of 0 (no limit).
	w.next('E')
	w.byte(0)
	w.int32(0)

	w.next('S')
	cn.send(w)
}
// processParameterStatus handles a ParameterStatus message and records the
// server settings that affect how values are encoded and decoded. Parameters
// other than server_version and TimeZone are ignored.
//
// server_version is converted to the server_version_num numbering:
//   - "9.6.3" becomes 90603 (major1*10000 + major2*100 + minor).
//   - PostgreSQL 10 and later use a two-part version and become
//     major*10000 + minor, so "10.1" becomes 100001.
//
// Parsing with "%d.%d.%d" alone rejected two-part versions and left
// serverVersion at 0 for every PostgreSQL 10+ server. If no usable version
// can be parsed, the previously recorded value is kept.
//
// A TimeZone name that time.LoadLocation cannot resolve leaves
// currentLocation nil.
func (cn *conn) processParameterStatus(r *readBuf) {
	switch r.string() {
	case "server_version":
		// Sscanf reports how many fields it filled before the input stopped
		// matching. Inputs like "10.1", "10beta1" or "9.6.3 (Debian ...)"
		// still yield their leading numbers, so the error is not needed.
		// Unfilled fields keep their zero value.
		var major1, major2, minor int
		n, _ := fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
		switch {
		case n >= 1 && major1 >= 10:
			// PostgreSQL 10+: MAJOR.MINOR. A pre-release such as "10beta1"
			// has no minor part.
			cn.parameterStatus.serverVersion = major1*10000 + major2
		case n >= 2:
			// Before PostgreSQL 10: MAJOR1.MAJOR2.MINOR. The minor part is
			// absent for pre-releases such as "9.6beta1".
			cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
		}
	case "TimeZone":
		loc, err := time.LoadLocation(r.string())
		if err != nil {
			loc = nil
		}
		cn.parameterStatus.currentLocation = loc
	default:
		// ignore
	}
}
// processReadyForQuery stores the transaction-status byte carried by a
// ReadyForQuery message in cn.txnStatus.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := r.byte()
	cn.txnStatus = transactionStatus(status)
}
// readReadyForQuery reads one message and requires it to be ReadyForQuery
// ('Z'), whose status is then recorded. Any other message marks the
// connection bad and is reported through errorf.
func (cn *conn) readReadyForQuery() {
	msgType, buf := cn.recv1()
	if msgType != 'Z' {
		cn.bad = true
		errorf("unexpected message %q; expected ReadyForQuery", msgType)
	}
	cn.processReadyForQuery(buf)
}
// readParseResponse consumes the reply to a Parse message.
//   - ParseComplete ('1'): success.
//   - ErrorResponse ('E'): the error is panicked with, after draining to
//     ReadyForQuery.
//   - Anything else: marks the connection bad and is reported through errorf.
func (cn *conn) readParseResponse() {
	msgType, buf := cn.recv1()
	if msgType == '1' {
		return
	}
	if msgType == 'E' {
		err := parseError(buf)
		cn.readReadyForQuery()
		panic(err)
	}
	cn.bad = true
	errorf("unexpected Parse response %q", msgType)
}
// readStatementDescribeResponse consumes the reply to a Describe (statement)
// message. A ParameterDescription ('t') supplies the parameter type OIDs. It
// is followed by either NoData ('n'), for statements that return no rows, or
// a RowDescription ('T') giving column names and type OIDs.
//
// An ErrorResponse is panicked with after draining to ReadyForQuery. Any
// other message marks the connection bad and is reported through errorf.
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) {
	for {
		msgType, buf := cn.recv1()
		if msgType == 't' {
			count := buf.int16()
			paramTyps = make([]oid.Oid, count)
			for i := 0; i < count; i++ {
				paramTyps[i] = buf.oid()
			}
			continue
		}
		switch msgType {
		case 'n':
			return paramTyps, nil, nil
		case 'T':
			colNames, colTyps = parseStatementRowDescribe(buf)
			return paramTyps, colNames, colTyps
		case 'E':
			err := parseError(buf)
			cn.readReadyForQuery()
			panic(err)
		default:
			cn.bad = true
			errorf("unexpected Describe statement response %q", msgType)
		}
	}
}
// readPortalDescribeResponse consumes the reply to a Describe (portal)
// message.
//   - RowDescription ('T'): returns the column names, format codes and type
//     OIDs.
//   - NoData ('n'): returns all nils.
//   - ErrorResponse: panicked with after draining to ReadyForQuery.
//   - Anything else: marks the connection bad and is reported through errorf.
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) {
	msgType, buf := cn.recv1()
	switch msgType {
	case 'n':
		return nil, nil, nil
	case 'T':
		return parsePortalRowDescribe(buf)
	case 'E':
		err := parseError(buf)
		cn.readReadyForQuery()
		panic(err)
	}
	cn.bad = true
	errorf("unexpected Describe response %q", msgType)
	panic("not reached")
}
// readBindResponse consumes the reply to a Bind message.
//   - BindComplete ('2'): success.
//   - ErrorResponse: panicked with after draining to ReadyForQuery.
//   - Anything else: marks the connection bad and is reported through errorf.
func (cn *conn) readBindResponse() {
	switch msgType, buf := cn.recv1(); msgType {
	case '2':
		// BindComplete: nothing further to do.
	case 'E':
		err := parseError(buf)
		cn.readReadyForQuery()
		panic(err)
	default:
		cn.bad = true
		errorf("unexpected Bind response %q", msgType)
	}
}
// postExecuteWorkaround peeks at the first message after an Execute so that
// an execution error is reported to the caller directly.
//
// This works around sql.DB.QueryRow in Go 1.2 and earlier, which ignores
// errors from rows.Next and so would hide a failing query. Behavior depends
// on the message type:
//   - CommandComplete, DataRow or EmptyQueryResponse: the query did not fail
//     up front. The message is stashed with saveMessage so the next recv1
//     (from rows.Next or rows.Close) receives it.
//   - ErrorResponse: the error is panicked with after draining to
//     ReadyForQuery.
//   - Anything else: marks the connection bad and is reported through errorf.
func (cn *conn) postExecuteWorkaround() {
	msgType, buf := cn.recv1()
	switch msgType {
	case 'C', 'D', 'I':
		cn.saveMessage(msgType, buf)
	case 'E':
		err := parseError(buf)
		cn.readReadyForQuery()
		panic(err)
	default:
		cn.bad = true
		errorf("unexpected message during extended query execution: %q", msgType)
	}
}
// readExecuteResponse consumes the server's replies to an extended-protocol
// Execute, up to and including ReadyForQuery. It is only for Exec(), since
// any returned row data is ignored.
//
// It returns either:
//   - the Result and command tag from CommandComplete, or
//   - the error carried by an ErrorResponse.
//
// An EmptyQueryResponse ('I') stands in for CommandComplete when the query
// string is empty. In that case res is RowsAffected(0) rather than nil: a nil
// driver.Result would be wrapped by database/sql, and its RowsAffected method
// would panic on the nil interface.
//
// protocolState names the caller's phase and is only used in the message for
// an unrecognized reply. Protocol violations mark the connection bad and are
// reported through errorf.
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			if err != nil {
				cn.bad = true
				errorf("unexpected CommandComplete after error %s", err)
			}
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			return res, commandTag, err
		case 'E':
			// Keep reading until ReadyForQuery; the error is returned then.
			err = parseError(r)
		case 'T', 'D', 'I':
			if err != nil {
				cn.bad = true
				errorf("unexpected %q after error %s", t, err)
			}
			if t == 'I' {
				res = driver.RowsAffected(0)
			}
			// ignore any results
		default:
			cn.bad = true
			errorf("unknown %s response: %q", protocolState, t)
		}
	}
}
// parseStatementRowDescribe reads a RowDescription message obtained by
// describing a prepared statement. It returns each column's name and type
// OID; the remaining per-column fields are skipped.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) {
	count := r.int16()
	colNames = make([]string, count)
	colTyps = make([]oid.Oid, count)
	for i := 0; i < count; i++ {
		colNames[i] = r.string()
		r.next(6) // table OID (4 bytes) + column attribute number (2 bytes)
		colTyps[i] = r.oid()
		r.next(6) // type size (2 bytes) + type modifier (4 bytes)
		// The format code is not known when describing a statement (it is
		// always 0), so it is skipped.
		r.next(2)
	}
	return colNames, colTyps
}
// parsePortalRowDescribe reads a RowDescription message obtained by
// describing a portal. It returns each column's name, format code and type
// OID.
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) {
	count := r.int16()
	colNames, colFmts, colTyps = make([]string, count), make([]format, count), make([]oid.Oid, count)
	for i := 0; i < count; i++ {
		colNames[i] = r.string()
		r.next(6) // table OID (4 bytes) + column attribute number (2 bytes)
		colTyps[i] = r.oid()
		r.next(6) // type size (2 bytes) + type modifier (4 bytes)
		colFmts[i] = format(r.int16())
	}
	return colNames, colFmts, colTyps
}
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output: a list of "KEY=VALUE" strings. Entries that
// contain no '=' carry no value and are skipped. Previously they caused an
// index-out-of-range panic when the key was a recognized PG variable.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// Well-defined but unsupported variables (e.g. PGSERVICE) cause a panic with
// the message "setting <NAME> not supported"; unset them before connecting.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)
	for _, v := range env {
		parts := strings.SplitN(v, "=", 2)
		if len(parts) != 2 {
			// Not a KEY=VALUE entry, so there is no value to record.
			continue
		}
		accrue := func(keyname string) {
			out[keyname] = parts[1]
		}
		unsupported := func() {
			panic(fmt.Sprintf("setting %v not supported", parts[0]))
		}
		// The order of these is the same as is seen in the
		// PostgreSQL 9.1 manual. Unsupported but well-defined
		// keys cause a panic; these should be unset prior to
		// execution. Options which pq expects to be set to a
		// certain value are allowed, but must be set to that
		// value if present (they can, of course, be absent).
		switch parts[0] {
		case "PGHOST":
			accrue("host")
		case "PGHOSTADDR":
			unsupported()
		case "PGPORT":
			accrue("port")
		case "PGDATABASE":
			accrue("dbname")
		case "PGUSER":
			accrue("user")
		case "PGPASSWORD":
			accrue("password")
		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
			unsupported()
		case "PGOPTIONS":
			accrue("options")
		case "PGAPPNAME":
			accrue("application_name")
		case "PGSSLMODE":
			accrue("sslmode")
		case "PGSSLCERT":
			accrue("sslcert")
		case "PGSSLKEY":
			accrue("sslkey")
		case "PGSSLROOTCERT":
			accrue("sslrootcert")
		case "PGREQUIRESSL", "PGSSLCRL":
			unsupported()
		case "PGREQUIREPEER":
			unsupported()
		case "PGKRBSRVNAME", "PGGSSLIB":
			unsupported()
		case "PGCONNECT_TIMEOUT":
			accrue("connect_timeout")
		case "PGCLIENTENCODING":
			accrue("client_encoding")
		case "PGDATESTYLE":
			accrue("datestyle")
		case "PGTZ":
			accrue("timezone")
		case "PGGEQO":
			accrue("geqo")
		case "PGSYSCONFDIR", "PGLOCALEDIR":
			unsupported()
		}
	}
	return out
}
// isUTF8 reports whether name loosely spells "UTF-8". Case, punctuation and
// any non-ASCII-alphanumeric characters are ignored, as Postgres does, so
// "utf8", "UTF-8" and "Unicode" all match.
func isUTF8(name string) bool {
	switch strings.Map(alnumLowerASCII, name) {
	case "utf8", "unicode":
		return true
	}
	return false
}
// alnumLowerASCII is a strings.Map mapping function. It lowercases ASCII
// letters, keeps ASCII digits and lowercase letters unchanged, and returns -1
// (drop the character) for everything else.
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1
	}
}