blob: a35ae895b29bd45bf0149b207fa4d692f1ae753a [file] [log] [blame] [edit]
// Copyright 2015 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build go1.8
package config
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
yaml "gopkg.in/yaml.v2"
)
// Fixture paths and expected values shared by the HTTP client config tests.
const (
	// TLSCAChainPath is the CA chain trusted by both sides of the test TLS
	// connections: newTestServer uses it as its client-CA pool, and the
	// client configs use it as CAFile.
	TLSCAChainPath = "testdata/tls-ca-chain.pem"
	// ServerCertificatePath and ServerKeyPath are the key pair newTestServer presents.
	ServerCertificatePath = "testdata/server.crt"
	ServerKeyPath = "testdata/server.key"
	// ClientCertificatePath and ClientKeyNoPassPath are the client key pair
	// that the test server accepts.
	ClientCertificatePath = "testdata/client.crt"
	ClientKeyNoPassPath = "testdata/client-no-pass.key"
	// InvalidCA deliberately points at a private key rather than a
	// certificate, so it cannot be used as a CA file.
	InvalidCA = "testdata/client-no-pass.key"
	// WrongClientCertPath and WrongClientKeyPath form a client key pair that
	// the test server rejects (self-signed, per the file names).
	WrongClientCertPath = "testdata/self-signed-client.crt"
	WrongClientKeyPath = "testdata/self-signed-client.key"
	// EmptyFile is used to check the errors produced for empty PEM input.
	EmptyFile = "testdata/empty"
	// Missing* are paths that are expected not to exist.
	MissingCA = "missing/ca.crt"
	MissingCert = "missing/cert.crt"
	MissingKey = "missing/secret.key"
	// ExpectedMessage is the body the test handlers return when a request passes their checks.
	ExpectedMessage = "I'm here to serve you!!!"
	// AuthorizationCredentials is the secret sent in the Authorization
	// header. AuthorizationCredentialsFile is expected to contain the same
	// value, since tests reading it expect the same header.
	AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
	AuthorizationCredentialsFile = "testdata/bearer.token"
	// AuthorizationType is a custom (non-Bearer) authorization scheme.
	AuthorizationType = "APIKEY"
	// BearerToken and BearerTokenFile alias the authorization credentials for
	// the bearer_token / bearer_token_file code paths.
	BearerToken = AuthorizationCredentials
	BearerTokenFile = AuthorizationCredentialsFile
	MissingBearerTokenFile = "missing/bearer.token"
	// ExpectedBearer and ExpectedAuthenticationCredentials are the full
	// Authorization header values the handlers expect.
	ExpectedBearer = "Bearer " + BearerToken
	ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken
	// ExpectedUsername and ExpectedPassword are the basic-auth credentials
	// the basic-auth handler checks for.
	ExpectedUsername = "arthurdent"
	ExpectedPassword = "42"
)
// invalidHTTPClientConfigs pairs YAML configs that must fail to load with a
// substring that TestInvalidHTTPConfigs expects to find in the load error.
var invalidHTTPClientConfigs = []struct {
	// httpClientConfigFile is the path of the YAML file to load.
	httpClientConfigFile string
	// errMsg must be contained in the returned error's message.
	errMsg string
}{
	{
		httpClientConfigFile: "testdata/http.conf.bearer-token-and-file-set.bad.yml",
		errMsg: "at most one of bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.empty.bad.yml",
		errMsg: "at most one of basic_auth, bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml",
		errMsg: "at most one of basic_auth password & password_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml",
		errMsg: "authorization is not compatible with bearer_token & bearer_token_file",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-and-file-set.too-much.bad.yaml",
		errMsg: "at most one of authorization credentials & credentials_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth-and-auth-creds.too-much.bad.yaml",
		errMsg: "at most one of basic_auth & authorization must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml",
		errMsg: `authorization type cannot be set to "basic", use "basic_auth" instead`,
	},
}
// newTestServer starts an HTTPS test server that serves handler.
//
// The server presents the key pair at ServerCertificatePath/ServerKeyPath and
// requires every client to present a certificate verifiable against the CA
// chain at TLSCAChainPath. The caller must Close the returned server.
func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) {
	// Load all TLS material before creating the server:
	// httptest.NewUnstartedServer already opens a listener, which would leak
	// if we returned an error afterwards.
	tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath)
	if err != nil {
		return nil, fmt.Errorf("Can't read %s: %v", TLSCAChainPath, err)
	}
	serverCertificate, err := tls.LoadX509KeyPair(ServerCertificatePath, ServerKeyPath)
	if err != nil {
		return nil, fmt.Errorf("Can't load X509 key pair %s - %s: %v", ServerCertificatePath, ServerKeyPath, err)
	}
	rootCAs := x509.NewCertPool()
	if !rootCAs.AppendCertsFromPEM(tlsCAChain) {
		// An empty pool would make every client certificate fail verification
		// with a confusing TLS error, so fail fast with a clear message.
		return nil, fmt.Errorf("Can't parse any certificate from %s", TLSCAChainPath)
	}

	testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler))
	testServer.TLS = &tls.Config{
		Certificates: []tls.Certificate{serverCertificate},
		RootCAs:      rootCAs,
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    rootCAs,
	}
	testServer.StartTLS()
	return testServer, nil
}
func TestNewClientFromConfig(t *testing.T) {
var newClientValidConfig = []struct {
clientConfig HTTPClientConfig
handler func(w http.ResponseWriter, r *http.Request)
}{
{
clientConfig: HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: "",
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: true},
},
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
},
}, {
clientConfig: HTTPClientConfig{
BearerToken: BearerToken,
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedBearer {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedBearer, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
BearerTokenFile: BearerTokenFile,
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedBearer {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedBearer, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
Authorization: &Authorization{Credentials: BearerToken},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedBearer {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedBearer, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
Authorization: &Authorization{CredentialsFile: AuthorizationCredentialsFile, Type: AuthorizationType},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedAuthenticationCredentials {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedAuthenticationCredentials, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
Authorization: &Authorization{
Credentials: AuthorizationCredentials,
Type: AuthorizationType,
},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedAuthenticationCredentials {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedAuthenticationCredentials, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
Authorization: &Authorization{
CredentialsFile: BearerTokenFile,
},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
bearer := r.Header.Get("Authorization")
if bearer != ExpectedBearer {
fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
ExpectedBearer, bearer)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
}, {
clientConfig: HTTPClientConfig{
BasicAuth: &BasicAuth{
Username: ExpectedUsername,
Password: ExpectedPassword,
},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
fmt.Fprintf(w, "The Authorization header wasn't set")
} else if ExpectedUsername != username {
fmt.Fprintf(w, "The expected username (%s) differs from the obtained username (%s).", ExpectedUsername, username)
} else if ExpectedPassword != password {
fmt.Fprintf(w, "The expected password (%s) differs from the obtained password (%s).", ExpectedPassword, password)
} else {
fmt.Fprint(w, ExpectedMessage)
}
},
},
}
for _, validConfig := range newClientValidConfig {
testServer, err := newTestServer(validConfig.handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()
err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test", false, true)
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
}
response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v", validConfig.clientConfig)
continue
}
message, err := ioutil.ReadAll(response.Body)
response.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
continue
}
trimMessage := strings.TrimSpace(string(message))
if ExpectedMessage != trimMessage {
t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
ExpectedMessage, trimMessage, validConfig.clientConfig)
}
}
}
// TestNewClientFromInvalidConfig checks that NewClientFromConfig rejects CA
// files that are missing or are not certificates, returning a nil client and
// a descriptive error.
func TestNewClientFromInvalidConfig(t *testing.T) {
	var newClientInvalidConfig = []struct {
		clientConfig HTTPClientConfig
		errorMsg     string
	}{
		{
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             MissingCA,
					InsecureSkipVerify: true},
			},
			errorMsg: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA),
		},
		{
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             InvalidCA,
					InsecureSkipVerify: true},
			},
			errorMsg: fmt.Sprintf("unable to use specified CA cert %s", InvalidCA),
		},
	}
	for _, invalidConfig := range newClientInvalidConfig {
		client, err := NewClientFromConfig(invalidConfig.clientConfig, "test", false, true)
		if client != nil {
			t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig)
		}
		if err == nil {
			t.Errorf("No error was returned using this config: %+v", invalidConfig.clientConfig)
			// There is no message to inspect; err.Error() would panic.
			continue
		}
		if !strings.Contains(err.Error(), invalidConfig.errorMsg) {
			t.Errorf("Obtained error %q does not contain %q", err.Error(), invalidConfig.errorMsg)
		}
	}
}
// TestMissingBearerAuthFile checks that a client configured with a
// non-existent bearer token file can be created, but fails each request with
// an error naming the unreadable file.
func TestMissingBearerAuthFile(t *testing.T) {
	cfg := HTTPClientConfig{
		BearerTokenFile: MissingBearerTokenFile,
		TLSConfig: TLSConfig{
			CAFile:             TLSCAChainPath,
			CertFile:           ClientCertificatePath,
			KeyFile:            ClientKeyNoPassPath,
			ServerName:         "",
			InsecureSkipVerify: false},
	}
	handler := func(w http.ResponseWriter, r *http.Request) {
		bearer := r.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		} else {
			fmt.Fprint(w, ExpectedMessage)
		}
	}
	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()

	client, err := NewClientFromConfig(cfg, "test", false, true)
	if err != nil {
		t.Fatal(err)
	}
	response, err := client.Get(testServer.URL)
	if err == nil {
		response.Body.Close()
		t.Fatal("No error is returned here")
	}
	// Match only the portable part of the message: the trailing OS error text
	// (e.g. "no such file or directory") differs between platforms.
	expected := fmt.Sprintf("unable to read authorization credentials file %s: open %s",
		MissingBearerTokenFile, MissingBearerTokenFile)
	if !strings.Contains(err.Error(), expected) {
		t.Fatalf("wrong error message being returned: got %q, want it to contain %q", err, expected)
	}
}
// TestBearerAuthRoundTripper checks that the credentials round tripper adds
// "Bearer <token>" to outgoing requests and leaves an already present
// Authorization header untouched.
func TestBearerAuthRoundTripper(t *testing.T) {
	const (
		newBearerToken = "goodbyeandthankyouforthefish"
	)
	fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) {
		bearer := req.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		}
	}, nil, nil)

	// Normal flow.
	bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", BearerToken, fakeRoundTripper)
	request, err := http.NewRequest("GET", "/hitchhiker", nil)
	if err != nil {
		t.Fatalf("unexpected error while creating request: %s", err)
	}
	request.Header.Set("User-Agent", "Douglas Adams mind")
	if _, err := bearerAuthRoundTripper.RoundTrip(request); err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}

	// Should honor already Authorization header set: the round tripper holds
	// a different token, so the check only passes if the header is preserved.
	bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", newBearerToken, fakeRoundTripper)
	request, err = http.NewRequest("GET", "/hitchhiker", nil)
	if err != nil {
		t.Fatalf("unexpected error while creating request: %s", err)
	}
	request.Header.Set("Authorization", ExpectedBearer)
	if _, err := bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request); err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}
}
// TestBearerAuthFileRoundTripper checks that the file-based credentials round
// tripper adds the token read from BearerTokenFile, and that it does not touch
// (or even need to read) its file when the request already has an
// Authorization header.
func TestBearerAuthFileRoundTripper(t *testing.T) {
	fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) {
		bearer := req.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		}
	}, nil, nil)

	// Normal flow.
	bearerAuthRoundTripper := NewAuthorizationCredentialsFileRoundTripper("Bearer", BearerTokenFile, fakeRoundTripper)
	request, err := http.NewRequest("GET", "/hitchhiker", nil)
	if err != nil {
		t.Fatalf("unexpected error while creating request: %s", err)
	}
	request.Header.Set("User-Agent", "Douglas Adams mind")
	if _, err := bearerAuthRoundTripper.RoundTrip(request); err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}

	// Should honor already Authorization header set. The token file is
	// missing, so this only succeeds if the file is never read.
	bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsFileRoundTripper("Bearer", MissingBearerTokenFile, fakeRoundTripper)
	request, err = http.NewRequest("GET", "/hitchhiker", nil)
	if err != nil {
		t.Fatalf("unexpected error while creating request: %s", err)
	}
	request.Header.Set("Authorization", ExpectedBearer)
	if _, err := bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request); err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}
}
// TestTLSConfig checks that NewTLSConfig loads the CA pool, server name and
// client key pair described by a fully populated TLSConfig.
func TestTLSConfig(t *testing.T) {
	input := TLSConfig{
		CAFile:             TLSCAChainPath,
		CertFile:           ClientCertificatePath,
		KeyFile:            ClientKeyNoPassPath,
		ServerName:         "localhost",
		InsecureSkipVerify: false,
	}

	// Build the tls.Config we expect NewTLSConfig to produce.
	caPEM, err := ioutil.ReadFile(TLSCAChainPath)
	if err != nil {
		t.Fatalf("Can't read the CA certificate chain (%s)",
			TLSCAChainPath)
	}
	pool := x509.NewCertPool()
	pool.AppendCertsFromPEM(caPEM)
	want := &tls.Config{
		RootCAs:            pool,
		ServerName:         input.ServerName,
		InsecureSkipVerify: input.InsecureSkipVerify,
	}

	got, err := NewTLSConfig(&input)
	if err != nil {
		t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err)
	}

	// The client certificate is served through a callback, so check it separately.
	wantCert, err := tls.LoadX509KeyPair(ClientCertificatePath, ClientKeyNoPassPath)
	if err != nil {
		t.Fatalf("Can't load the client key pair ('%s' and '%s'). Reason: %s",
			ClientCertificatePath, ClientKeyNoPassPath, err)
	}
	gotCert, err := got.GetClientCertificate(nil)
	if err != nil {
		t.Fatalf("unexpected error returned by tlsConfig.GetClientCertificate(): %s", err)
	}
	if !reflect.DeepEqual(gotCert, &wantCert) {
		t.Fatalf("Unexpected client certificate result: \n\n%+v\n expected\n\n%+v", gotCert, wantCert)
	}

	// Function values are DeepEqual only when both are nil, so clear the
	// callback before comparing the remaining fields.
	got.GetClientCertificate = nil
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", got, want)
	}
}
// TestTLSConfigEmpty checks that a TLSConfig with only InsecureSkipVerify set
// yields a tls.Config with nothing else populated.
func TestTLSConfigEmpty(t *testing.T) {
	input := TLSConfig{InsecureSkipVerify: true}
	want := &tls.Config{InsecureSkipVerify: input.InsecureSkipVerify}

	got, err := NewTLSConfig(&input)
	if err != nil {
		t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err)
	}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", got, want)
	}
}
// TestTLSConfigInvalidCA checks that NewTLSConfig returns a descriptive error
// when the CA file, client certificate or client key is missing.
func TestTLSConfigInvalidCA(t *testing.T) {
	var invalidTLSConfig = []struct {
		configTLSConfig TLSConfig
		errorMessage    string
	}{
		{
			configTLSConfig: TLSConfig{
				CAFile:             MissingCA,
				CertFile:           "",
				KeyFile:            "",
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA),
		}, {
			configTLSConfig: TLSConfig{
				CAFile:             "",
				CertFile:           MissingCert,
				KeyFile:            ClientKeyNoPassPath,
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
		}, {
			configTLSConfig: TLSConfig{
				CAFile:             "",
				CertFile:           ClientCertificatePath,
				KeyFile:            MissingKey,
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
		},
	}
	for _, anInvalidTLSConfig := range invalidTLSConfig {
		_, err := NewTLSConfig(&anInvalidTLSConfig.configTLSConfig)
		// Any nil error is a failure, whatever the returned config. Checking
		// only "non-nil config && nil error" would let a nil/nil result
		// through to err.Error() below and panic.
		if err == nil {
			t.Errorf("The TLS Config could be created even with this %+v", anInvalidTLSConfig.configTLSConfig)
			continue
		}
		if !strings.Contains(err.Error(), anInvalidTLSConfig.errorMessage) {
			t.Errorf("The expected error should contain %s, but got %s", anInvalidTLSConfig.errorMessage, err)
		}
	}
}
// TestBasicAuthNoPassword loads a basic-auth config that sets only a username
// and checks the resulting transport carries no password or password file.
func TestBasicAuthNoPassword(t *testing.T) {
	const configFile = "testdata/http.conf.basic-auth.no-password.yaml"

	cfg, _, err := LoadHTTPConfigFile(configFile)
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test", false, true)
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	transport, isBasicAuth := client.Transport.(*basicAuthRoundTripper)
	if !isBasicAuth {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}
	if user := transport.username; user != "user" {
		t.Errorf("Bad HTTP client username: %s", user)
	}
	if pw := transport.password; string(pw) != "" {
		t.Errorf("Expected empty HTTP client password: %s", pw)
	}
	if pwFile := transport.passwordFile; string(pwFile) != "" {
		t.Errorf("Expected empty HTTP client passwordFile: %s", pwFile)
	}
}
// TestBasicAuthNoUsername loads a basic-auth config that sets only a password
// and checks the resulting transport has an empty username and the inline
// password.
func TestBasicAuthNoUsername(t *testing.T) {
	const configFile = "testdata/http.conf.basic-auth.no-username.yaml"

	cfg, _, err := LoadHTTPConfigFile(configFile)
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test", false, true)
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	transport, isBasicAuth := client.Transport.(*basicAuthRoundTripper)
	if !isBasicAuth {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}
	if user := transport.username; user != "" {
		t.Errorf("Got unexpected username: %s", user)
	}
	if pw := string(transport.password); pw != "secret" {
		t.Errorf("Unexpected HTTP client password: %s", pw)
	}
	if pwFile := transport.passwordFile; string(pwFile) != "" {
		t.Errorf("Expected empty HTTP client passwordFile: %s", pwFile)
	}
}
// TestBasicAuthPasswordFile loads a basic-auth config that uses password_file
// and checks that the transport keeps the file path rather than an inline
// password.
func TestBasicAuthPasswordFile(t *testing.T) {
	const configFile = "testdata/http.conf.basic-auth.good.yaml"

	cfg, _, err := LoadHTTPConfigFile(configFile)
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test", false, true)
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	transport, isBasicAuth := client.Transport.(*basicAuthRoundTripper)
	if !isBasicAuth {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}
	if user := transport.username; user != "user" {
		t.Errorf("Bad HTTP client username: %s", user)
	}
	if pw := transport.password; string(pw) != "" {
		t.Errorf("Bad HTTP client password: %s", pw)
	}
	if pwFile := transport.passwordFile; string(pwFile) != "testdata/basic-auth-password" {
		t.Errorf("Bad HTTP client passwordFile: %s", pwFile)
	}
}
// getCertificateBlobs reads every TLS fixture file into memory, keyed by its
// path. The tests use it to rewrite the files a client points at. Any read
// failure aborts the calling test.
func getCertificateBlobs(t *testing.T) map[string][]byte {
	paths := []string{
		TLSCAChainPath,
		ClientCertificatePath,
		ClientKeyNoPassPath,
		ServerCertificatePath,
		ServerKeyPath,
		WrongClientCertPath,
		WrongClientKeyPath,
		EmptyFile,
	}
	blobs := make(map[string][]byte, len(paths)+1)
	for i := range paths {
		content, err := ioutil.ReadFile(paths[i])
		if err != nil {
			t.Fatal(err)
		}
		blobs[paths[i]] = content
	}
	return blobs
}
// writeCertificate writes the in-memory contents of fixture src (from bs) to
// the file dst with mode 0664. It panics if src is not in bs or the write
// fails, because it is also called from goroutines where t.Fatal cannot be
// used.
func writeCertificate(bs map[string][]byte, src string, dst string) {
	content, found := bs[src]
	if !found {
		panic(fmt.Sprintf("Couldn't find %q in bs", src))
	}
	err := ioutil.WriteFile(dst, content, 0664)
	if err != nil {
		panic(err)
	}
}
// TestTLSRoundTripper checks that the TLS material a client uses follows the
// files on disk. A single client is created in the first subtest and reused by
// all later ones, while the CA, certificate and key files it points at are
// overwritten before every case. Each case then expects the outcome that
// matches the newly written files, without recreating the client.
func TestTLSRoundTripper(t *testing.T) {
	bs := getCertificateBlobs(t)
	tmpDir, err := ioutil.TempDir("", "tlsroundtripper")
	if err != nil {
		t.Fatal("Failed to create tmp dir", err)
	}
	defer os.RemoveAll(tmpDir)
	// Fixed paths the client config points at; their contents change per case.
	ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")
	handler := func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, ExpectedMessage)
	}
	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()
	// The cases run in order and depend on it: the first must succeed so the
	// shared client can be created, and the last checks that a valid setup
	// works again after the failing ones.
	testCases := []struct {
		ca     string
		cert   string
		key    string
		errMsg string // empty means the request must succeed
	}{
		{
			// Valid certs.
			ca:   TLSCAChainPath,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,
		},
		{
			// CA not matching.
			ca:     ClientCertificatePath,
			cert:   ClientCertificatePath,
			key:    ClientKeyNoPassPath,
			errMsg: "certificate signed by unknown authority",
		},
		{
			// Invalid client cert+key.
			ca:     TLSCAChainPath,
			cert:   WrongClientCertPath,
			key:    WrongClientKeyPath,
			errMsg: "remote error: tls",
		},
		{
			// CA file empty
			ca:     EmptyFile,
			cert:   ClientCertificatePath,
			key:    ClientKeyNoPassPath,
			errMsg: "unable to use specified CA cert",
		},
		{
			// cert file empty
			ca:     TLSCAChainPath,
			cert:   EmptyFile,
			key:    ClientKeyNoPassPath,
			errMsg: "failed to find any PEM data in certificate input",
		},
		{
			// key file empty
			ca:     TLSCAChainPath,
			cert:   ClientCertificatePath,
			key:    EmptyFile,
			errMsg: "failed to find any PEM data in key input",
		},
		{
			// Valid certs again.
			ca:   TLSCAChainPath,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,
		},
	}
	cfg := HTTPClientConfig{
		TLSConfig: TLSConfig{
			CAFile:             ca,
			CertFile:           cert,
			KeyFile:            key,
			InsecureSkipVerify: false},
	}
	// c is shared across subtests; it is created lazily by the first one.
	var c *http.Client
	for i, tc := range testCases {
		tc := tc
		// Subtests run sequentially (no t.Parallel), so sharing c and the
		// outer err is safe.
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			writeCertificate(bs, tc.ca, ca)
			writeCertificate(bs, tc.cert, cert)
			writeCertificate(bs, tc.key, key)
			if c == nil {
				c, err = NewClientFromConfig(cfg, "test", false, true)
				if err != nil {
					t.Fatalf("Error creating HTTP Client: %v", err)
				}
			}
			req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
			if err != nil {
				t.Fatalf("Error creating HTTP request: %v", err)
			}
			r, err := c.Do(req)
			if len(tc.errMsg) > 0 {
				// A failure is expected; a successful response is an error.
				if err == nil {
					r.Body.Close()
					t.Fatalf("Could connect to the test server.")
				}
				if !strings.Contains(err.Error(), tc.errMsg) {
					t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err)
				}
				return
			}
			if err != nil {
				t.Fatalf("Can't connect to the test server")
			}
			b, err := ioutil.ReadAll(r.Body)
			r.Body.Close()
			if err != nil {
				t.Errorf("Can't read the server response body")
			}
			got := strings.TrimSpace(string(b))
			if ExpectedMessage != got {
				t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got)
			}
		})
	}
}
// TestTLSRoundTripperRaces sends requests from several goroutines while
// another goroutine keeps replacing the CA file on disk. Run under -race it
// detects unsynchronized access in the TLS reloading path. It also expects
// that at least some requests fail while the wrong CA is in place.
func TestTLSRoundTripperRaces(t *testing.T) {
	bs := getCertificateBlobs(t)
	tmpDir, err := ioutil.TempDir("", "tlsroundtripper")
	if err != nil {
		t.Fatal("Failed to create tmp dir", err)
	}
	defer os.RemoveAll(tmpDir)
	ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")

	handler := func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, ExpectedMessage)
	}
	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()

	cfg := HTTPClientConfig{
		TLSConfig: TLSConfig{
			CAFile:             ca,
			CertFile:           cert,
			KeyFile:            key,
			InsecureSkipVerify: false},
	}
	writeCertificate(bs, TLSCAChainPath, ca)
	writeCertificate(bs, ClientCertificatePath, cert)
	writeCertificate(bs, ClientKeyNoPassPath, key)
	c, err := NewClientFromConfig(cfg, "test", false, true)
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	var wg sync.WaitGroup
	// ch is closed by the CA-rotating goroutine to tell the pollers to stop.
	ch := make(chan struct{})
	var total, ok int64

	// Spawn 10 Go routines polling the server concurrently.
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case <-ch:
					return
				default:
					atomic.AddInt64(&total, 1)
					r, err := c.Get(testServer.URL)
					if err == nil {
						r.Body.Close()
						atomic.AddInt64(&ok, 1)
					}
				}
			}
		}()
	}

	// Every 10ms, alternate the CA file between a non-matching CA (the client
	// certificate) and the real CA chain, 101 times in total (~1 second).
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Use one ticker for the whole loop. Creating a new ticker on every
		// pass without stopping it leaks one ticker per iteration.
		tick := time.NewTicker(10 * time.Millisecond)
		defer tick.Stop()
		i := 0
		for {
			<-tick.C
			if i%2 == 0 {
				writeCertificate(bs, ClientCertificatePath, ca)
			} else {
				writeCertificate(bs, TLSCAChainPath, ca)
			}
			i++
			if i > 100 {
				close(ch)
				return
			}
		}
	}()

	wg.Wait()
	// All goroutines have exited, so the counters are no longer written concurrently.
	if ok == total {
		t.Fatalf("Expecting some requests to fail but got %d/%d successful requests", ok, total)
	}
}
// TestHideHTTPClientConfigSecrets checks that HTTPClientConfig.String does not
// reveal the secret ("mysecret") stored in the good test config.
func TestHideHTTPClientConfigSecrets(t *testing.T) {
	c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml")
	if err != nil {
		// Fatal, not Error: c is nil on failure and c.String() below would panic.
		t.Fatalf("Error parsing %s: %s", "testdata/http.conf.good.yml", err)
	}

	// String method must not reveal authentication credentials.
	s := c.String()
	if strings.Contains(s, "mysecret") {
		t.Fatal("http client config's String method reveals authentication credentials.")
	}
}
// TestValidateHTTPConfig checks that the good test config loads and passes Validate.
func TestValidateHTTPConfig(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml")
	if err != nil {
		// Fatal, not Error: cfg is nil on failure and cfg.Validate() would
		// panic instead of reporting the load error.
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	if err := cfg.Validate(); err != nil {
		t.Fatalf("Error validating %s: %s", "testdata/http.conf.good.yml", err)
	}
}
// TestInvalidHTTPConfigs checks that every file in invalidHTTPClientConfigs
// fails to load with an error containing the expected message.
func TestInvalidHTTPConfigs(t *testing.T) {
	for _, ee := range invalidHTTPClientConfigs {
		_, _, err := LoadHTTPConfigFile(ee.httpClientConfigFile)
		if err == nil {
			// Name the file so a failure points at the offending fixture.
			t.Errorf("Expected error with config %s but got none", ee.httpClientConfigFile)
			continue
		}
		if !strings.Contains(err.Error(), ee.errMsg) {
			t.Errorf("Expected error for invalid HTTP client configuration %s to contain %q but got: %s",
				ee.httpClientConfigFile, ee.errMsg, err)
		}
	}
}
// LoadHTTPConfig parses the YAML input s into a HTTPClientConfig using
// yaml.UnmarshalStrict. On any parse error it returns a nil config together
// with that error.
func LoadHTTPConfig(s string) (*HTTPClientConfig, error) {
	cfg := new(HTTPClientConfig)
	if err := yaml.UnmarshalStrict([]byte(s), cfg); err != nil {
		return nil, err
	}
	return cfg, nil
}
// LoadHTTPConfigFile reads filename and parses it with LoadHTTPConfig. It
// returns the parsed config and the raw file contents. If reading or parsing
// fails, both are nil and the error is returned.
func LoadHTTPConfigFile(filename string) (*HTTPClientConfig, []byte, error) {
	raw, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, nil, err
	}
	parsed, err := LoadHTTPConfig(string(raw))
	if err != nil {
		return nil, nil, err
	}
	return parsed, raw, nil
}
// roundTrip is a canned http.RoundTripper: it ignores the request and always
// returns the stored response and error.
type roundTrip struct {
	theResponse *http.Response
	theError    error
}

// RoundTrip implements http.RoundTripper by returning the canned result.
func (c *roundTrip) RoundTrip(_ *http.Request) (*http.Response, error) {
	return c.theResponse, c.theError
}

// roundTripCheckRequest is a roundTrip that first passes each request to
// checkRequest, so tests can make assertions about what was sent.
type roundTripCheckRequest struct {
	checkRequest func(*http.Request)
	roundTrip
}

// RoundTrip runs checkRequest on r, then returns the canned result.
func (c *roundTripCheckRequest) RoundTrip(r *http.Request) (*http.Response, error) {
	c.checkRequest(r)
	return c.roundTrip.RoundTrip(r)
}

// NewRoundTripCheckRequest creates a new instance of a type that implements http.RoundTripper,
// which before returning theResponse and theError, executes checkRequest against a http.Request.
func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *http.Response, theError error) http.RoundTripper {
	canned := roundTrip{theResponse: theResponse, theError: theError}
	return &roundTripCheckRequest{checkRequest: checkRequest, roundTrip: canned}
}