feat: Add TLS Handshake timeout support (#530)
* feat: Add TLS Handshake timeout support
Add support for configuring a timeout for TLS Handshake call via
DialTLSHandshakeTimeout DialOption. If no option is specified then the
default timeout is 10 seconds.
Also:
* Add a default connect timeout of 30 seconds matching that of net/http.
Fixes #509
diff --git a/redis/conn.go b/redis/conn.go
index 33b43be..5d7841c 100644
--- a/redis/conn.go
+++ b/redis/conn.go
@@ -75,17 +75,27 @@
}
type dialOptions struct {
- readTimeout time.Duration
- writeTimeout time.Duration
- dialer *net.Dialer
- dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
- db int
- username string
- password string
- clientName string
- useTLS bool
- skipVerify bool
- tlsConfig *tls.Config
+ readTimeout time.Duration
+ writeTimeout time.Duration
+ tlsHandshakeTimeout time.Duration
+ dialer *net.Dialer
+ dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ db int
+ username string
+ password string
+ clientName string
+ useTLS bool
+ skipVerify bool
+ tlsConfig *tls.Config
+}
+
+// DialTLSHandshakeTimeout specifies the maximum amount of time waiting to
+// wait for a TLS handshake. Zero means no timeout.
+// If no DialTLSHandshakeTimeout option is specified then the default is 30 seconds.
+func DialTLSHandshakeTimeout(d time.Duration) DialOption {
+ return DialOption{func(do *dialOptions) {
+ do.tlsHandshakeTimeout = d
+ }}
}
// DialReadTimeout specifies the timeout for reading a single command reply.
@@ -104,6 +114,7 @@
// DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
+// If no DialConnectTimeout option is specified then the default is 30 seconds.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.Timeout = d
@@ -201,13 +212,21 @@
return DialContext(context.Background(), network, address, options...)
}
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "TLS handshake timeout" }
+
// DialContext connects to the Redis server at the given network and
// address using the specified options and context.
func DialContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dialer: &net.Dialer{
+ Timeout: time.Second * 30,
KeepAlive: time.Minute * 5,
},
+ tlsHandshakeTimeout: time.Second * 10,
}
for _, option := range options {
option.f(&do)
@@ -238,10 +257,22 @@
}
tlsConn := tls.Client(netConn, tlsConfig)
- if err := tlsConn.Handshake(); err != nil {
- netConn.Close()
+ errc := make(chan error, 2) // buffered so we don't block timeout or Handshake
+ if d := do.tlsHandshakeTimeout; d != 0 {
+ timer := time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ defer timer.Stop()
+ }
+ go func() {
+ errc <- tlsConn.Handshake()
+ }()
+ if err := <-errc; err != nil {
+ // Timeout or Handshake error.
+ netConn.Close() // nolint: errcheck
return nil, err
}
+
netConn = tlsConn
}
diff --git a/redis/conn_test.go b/redis/conn_test.go
index dbc66e7..97d7bec 100644
--- a/redis/conn_test.go
+++ b/redis/conn_test.go
@@ -701,6 +701,45 @@
checkPingPong(t, &buf, c)
}
+type blockedReader struct {
+ ch chan struct{}
+}
+
+func (b blockedReader) Read(p []byte) (n int, err error) {
+ <-b.ch
+ return 0, nil
+}
+
+func dialTestBlockedConn(ch chan struct{}, w io.Writer) redis.DialOption {
+ return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
+ return &testConn{Reader: blockedReader{ch: ch}, Writer: w}, nil
+ })
+}
+
+func TestDialTLSHandshakeTimeout(t *testing.T) {
+ var buf bytes.Buffer
+ ch := make(chan struct{})
+ var err error
+ go func() {
+ _, err = redis.Dial("tcp", "example.com:6379",
+ redis.DialTLSConfig(&clientTLSConfig),
+ redis.DialTLSHandshakeTimeout(time.Millisecond),
+ dialTestBlockedConn(ch, &buf),
+ redis.DialUseTLS(true))
+ close(ch)
+ }()
+ select {
+ case <-time.After(time.Second):
+ t.Fatal("dial didn't timeout")
+ case <-ch:
+ if err == nil {
+ t.Fatal("dial didn't error")
+ } else if err.Error() != "TLS handshake timeout" {
+ t.Fatal("dial unexpected error:", err)
+ }
+ }
+}
+
func TestDialTLSSKipVerify(t *testing.T) {
var buf bytes.Buffer
c, err := redis.Dial("tcp", "example.com:6379",