| // Copyright 2015 The Prometheus Authors |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // +build go1.8 |
| |
| package config |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "net" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "os" |
| "path/filepath" |
| "reflect" |
| "strconv" |
| "strings" |
| "sync" |
| "sync/atomic" |
| "testing" |
| "time" |
| |
| yaml "gopkg.in/yaml.v2" |
| ) |
| |
// Fixture paths and expected values shared by the tests in this file.
// Paths under testdata/ are expected to exist; paths under missing/ are
// intentionally absent so tests can exercise file-not-found handling.
const (
	TLSCAChainPath        = "testdata/tls-ca-chain.pem"
	ServerCertificatePath = "testdata/server.crt"
	ServerKeyPath         = "testdata/server.key"
	ClientCertificatePath = "testdata/client.crt"
	ClientKeyNoPassPath   = "testdata/client-no-pass.key"
	// InvalidCA deliberately points at a private key rather than a
	// certificate, so it cannot be used as a CA bundle.
	InvalidCA          = "testdata/client-no-pass.key"
	WrongClientCertPath = "testdata/self-signed-client.crt"
	WrongClientKeyPath  = "testdata/self-signed-client.key"
	EmptyFile           = "testdata/empty"
	MissingCA           = "missing/ca.crt"
	MissingCert         = "missing/cert.crt"
	MissingKey          = "missing/secret.key"

	// ExpectedMessage is what test handlers write when a request satisfies
	// the case's expectations; ExpectedBearer and
	// ExpectedAuthenticationCredentials are the full Authorization header
	// values the handlers compare against.
	ExpectedMessage                   = "I'm here to serve you!!!"
	ExpectedError                     = "expected error"
	AuthorizationCredentials          = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
	AuthorizationCredentialsFile      = "testdata/bearer.token"
	AuthorizationType                 = "APIKEY"
	BearerToken                       = AuthorizationCredentials
	BearerTokenFile                   = AuthorizationCredentialsFile
	MissingBearerTokenFile            = "missing/bearer.token"
	ExpectedBearer                    = "Bearer " + BearerToken
	ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken
	ExpectedUsername                  = "arthurdent"
	ExpectedPassword                  = "42"
)
| |
// invalidHTTPClientConfigs lists config files under testdata/ that must fail
// to load, each paired with a substring the resulting error must contain.
// It is consumed by TestInvalidHTTPConfigs.
var invalidHTTPClientConfigs = []struct {
	httpClientConfigFile string
	errMsg               string
}{
	{
		httpClientConfigFile: "testdata/http.conf.bearer-token-and-file-set.bad.yml",
		errMsg:               "at most one of bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.empty.bad.yml",
		errMsg:               "at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth password & password_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml",
		errMsg:               "authorization is not compatible with bearer_token & bearer_token_file",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-and-file-set.too-much.bad.yaml",
		errMsg:               "at most one of authorization credentials & credentials_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth-and-auth-creds.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth, oauth2 & authorization must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth-and-oauth2.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth, oauth2 & authorization must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml",
		errMsg:               `authorization type cannot be set to "basic", use "basic_auth" instead`,
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml",
		errMsg:               "at most one of oauth2 client_secret & client_secret_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml",
		errMsg:               "oauth2 client_id must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-client-secret.bad.yaml",
		errMsg:               "either oauth2 client_secret or client_secret_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-token-url.bad.yaml",
		errMsg:               "oauth2 token_url must be configured",
	},
}
| |
| func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) { |
| testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler)) |
| |
| tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath) |
| if err != nil { |
| return nil, fmt.Errorf("Can't read %s", TLSCAChainPath) |
| } |
| serverCertificate, err := tls.LoadX509KeyPair(ServerCertificatePath, ServerKeyPath) |
| if err != nil { |
| return nil, fmt.Errorf("Can't load X509 key pair %s - %s", ServerCertificatePath, ServerKeyPath) |
| } |
| |
| rootCAs := x509.NewCertPool() |
| rootCAs.AppendCertsFromPEM(tlsCAChain) |
| |
| testServer.TLS = &tls.Config{ |
| Certificates: make([]tls.Certificate, 1), |
| RootCAs: rootCAs, |
| ClientAuth: tls.RequireAndVerifyClientCert, |
| ClientCAs: rootCAs} |
| testServer.TLS.Certificates[0] = serverCertificate |
| |
| testServer.StartTLS() |
| |
| return testServer, nil |
| } |
| |
| func TestNewClientFromConfig(t *testing.T) { |
| var newClientValidConfig = []struct { |
| clientConfig HTTPClientConfig |
| handler func(w http.ResponseWriter, r *http.Request) |
| }{ |
| { |
| clientConfig: HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: "", |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: true}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| fmt.Fprint(w, ExpectedMessage) |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| fmt.Fprint(w, ExpectedMessage) |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| BearerToken: BearerToken, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| BearerTokenFile: BearerTokenFile, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| Authorization: &Authorization{Credentials: BearerToken}, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| Authorization: &Authorization{CredentialsFile: AuthorizationCredentialsFile, Type: AuthorizationType}, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedAuthenticationCredentials { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedAuthenticationCredentials, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| Authorization: &Authorization{ |
| Credentials: AuthorizationCredentials, |
| Type: AuthorizationType, |
| }, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedAuthenticationCredentials { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedAuthenticationCredentials, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| Authorization: &Authorization{ |
| CredentialsFile: BearerTokenFile, |
| }, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| BasicAuth: &BasicAuth{ |
| Username: ExpectedUsername, |
| Password: ExpectedPassword, |
| }, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| username, password, ok := r.BasicAuth() |
| if !ok { |
| fmt.Fprintf(w, "The Authorization header wasn't set") |
| } else if ExpectedUsername != username { |
| fmt.Fprintf(w, "The expected username (%s) differs from the obtained username (%s).", ExpectedUsername, username) |
| } else if ExpectedPassword != password { |
| fmt.Fprintf(w, "The expected password (%s) differs from the obtained password (%s).", ExpectedPassword, password) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| FollowRedirects: true, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| switch r.URL.Path { |
| case "/redirected": |
| fmt.Fprint(w, ExpectedMessage) |
| default: |
| w.Header().Set("Location", "/redirected") |
| w.WriteHeader(http.StatusFound) |
| fmt.Fprint(w, "It should follow the redirect.") |
| } |
| }, |
| }, { |
| clientConfig: HTTPClientConfig{ |
| FollowRedirects: false, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| }, |
| handler: func(w http.ResponseWriter, r *http.Request) { |
| switch r.URL.Path { |
| case "/redirected": |
| fmt.Fprint(w, "The redirection was followed.") |
| default: |
| w.Header().Set("Location", "/redirected") |
| w.WriteHeader(http.StatusFound) |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| }, |
| }, |
| } |
| |
| for _, validConfig := range newClientValidConfig { |
| testServer, err := newTestServer(validConfig.handler) |
| if err != nil { |
| t.Fatal(err.Error()) |
| } |
| defer testServer.Close() |
| |
| err = validConfig.clientConfig.Validate() |
| if err != nil { |
| t.Fatal(err.Error()) |
| } |
| client, err := NewClientFromConfig(validConfig.clientConfig, "test") |
| if err != nil { |
| t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) |
| continue |
| } |
| response, err := client.Get(testServer.URL) |
| if err != nil { |
| t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) |
| continue |
| } |
| |
| message, err := ioutil.ReadAll(response.Body) |
| response.Body.Close() |
| if err != nil { |
| t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) |
| continue |
| } |
| |
| trimMessage := strings.TrimSpace(string(message)) |
| if ExpectedMessage != trimMessage { |
| t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", |
| ExpectedMessage, trimMessage, validConfig.clientConfig) |
| } |
| } |
| } |
| |
| func TestNewClientFromInvalidConfig(t *testing.T) { |
| var newClientInvalidConfig = []struct { |
| clientConfig HTTPClientConfig |
| errorMsg string |
| }{ |
| { |
| clientConfig: HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: MissingCA, |
| InsecureSkipVerify: true}, |
| }, |
| errorMsg: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), |
| }, |
| { |
| clientConfig: HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: InvalidCA, |
| InsecureSkipVerify: true}, |
| }, |
| errorMsg: fmt.Sprintf("unable to use specified CA cert %s", InvalidCA), |
| }, |
| } |
| |
| for _, invalidConfig := range newClientInvalidConfig { |
| client, err := NewClientFromConfig(invalidConfig.clientConfig, "test") |
| if client != nil { |
| t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig) |
| } |
| if err == nil { |
| t.Errorf("No error was returned using this config: %+v", invalidConfig.clientConfig) |
| } |
| if !strings.Contains(err.Error(), invalidConfig.errorMsg) { |
| t.Errorf("Expected error %q does not contain %q", err.Error(), invalidConfig.errorMsg) |
| } |
| } |
| } |
| |
| func TestCustomDialContextFunc(t *testing.T) { |
| dialFn := func(_ context.Context, _, _ string) (net.Conn, error) { |
| return nil, errors.New(ExpectedError) |
| } |
| |
| cfg := HTTPClientConfig{} |
| client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn)) |
| if err != nil { |
| t.Fatalf("Can't create a client from this config: %+v", cfg) |
| } |
| |
| _, err = client.Get("http://localhost") |
| if err == nil || !strings.Contains(err.Error(), ExpectedError) { |
| t.Errorf("Expected error %q but got %q", ExpectedError, err) |
| } |
| } |
| |
| func TestCustomIdleConnTimeout(t *testing.T) { |
| timeout := time.Second * 5 |
| |
| cfg := HTTPClientConfig{} |
| rt, err := NewRoundTripperFromConfig(cfg, "test", WithIdleConnTimeout(timeout)) |
| if err != nil { |
| t.Fatalf("Can't create a round-tripper from this config: %+v", cfg) |
| } |
| |
| transport, ok := rt.(*http.Transport) |
| if !ok { |
| t.Fatalf("Unexpected transport: %+v", transport) |
| } |
| |
| if transport.IdleConnTimeout != timeout { |
| t.Fatalf("Unexpected idle connection timeout: %+v", timeout) |
| } |
| } |
| |
| func TestMissingBearerAuthFile(t *testing.T) { |
| cfg := HTTPClientConfig{ |
| BearerTokenFile: MissingBearerTokenFile, |
| TLSConfig: TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| } |
| handler := func(w http.ResponseWriter, r *http.Request) { |
| bearer := r.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } else { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| } |
| |
| testServer, err := newTestServer(handler) |
| if err != nil { |
| t.Fatal(err.Error()) |
| } |
| defer testServer.Close() |
| |
| client, err := NewClientFromConfig(cfg, "test") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Get(testServer.URL) |
| if err == nil { |
| t.Fatal("No error is returned here") |
| } |
| |
| if !strings.Contains(err.Error(), "unable to read authorization credentials file missing/bearer.token: open missing/bearer.token: no such file or directory") { |
| t.Fatal("wrong error message being returned") |
| } |
| } |
| |
| func TestBearerAuthRoundTripper(t *testing.T) { |
| const ( |
| newBearerToken = "goodbyeandthankyouforthefish" |
| ) |
| |
| fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) { |
| bearer := req.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } |
| }, nil, nil) |
| |
| // Normal flow. |
| bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", BearerToken, fakeRoundTripper) |
| request, _ := http.NewRequest("GET", "/hitchhiker", nil) |
| request.Header.Set("User-Agent", "Douglas Adams mind") |
| _, err := bearerAuthRoundTripper.RoundTrip(request) |
| if err != nil { |
| t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) |
| } |
| |
| // Should honor already Authorization header set. |
| bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", newBearerToken, fakeRoundTripper) |
| request, _ = http.NewRequest("GET", "/hitchhiker", nil) |
| request.Header.Set("Authorization", ExpectedBearer) |
| _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) |
| if err != nil { |
| t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) |
| } |
| } |
| |
| func TestBearerAuthFileRoundTripper(t *testing.T) { |
| fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) { |
| bearer := req.Header.Get("Authorization") |
| if bearer != ExpectedBearer { |
| t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", |
| ExpectedBearer, bearer) |
| } |
| }, nil, nil) |
| |
| // Normal flow. |
| bearerAuthRoundTripper := NewAuthorizationCredentialsFileRoundTripper("Bearer", BearerTokenFile, fakeRoundTripper) |
| request, _ := http.NewRequest("GET", "/hitchhiker", nil) |
| request.Header.Set("User-Agent", "Douglas Adams mind") |
| _, err := bearerAuthRoundTripper.RoundTrip(request) |
| if err != nil { |
| t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) |
| } |
| |
| // Should honor already Authorization header set. |
| bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsFileRoundTripper("Bearer", MissingBearerTokenFile, fakeRoundTripper) |
| request, _ = http.NewRequest("GET", "/hitchhiker", nil) |
| request.Header.Set("Authorization", ExpectedBearer) |
| _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) |
| if err != nil { |
| t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) |
| } |
| } |
| |
| func TestTLSConfig(t *testing.T) { |
| configTLSConfig := TLSConfig{ |
| CAFile: TLSCAChainPath, |
| CertFile: ClientCertificatePath, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "localhost", |
| InsecureSkipVerify: false} |
| |
| tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath) |
| if err != nil { |
| t.Fatalf("Can't read the CA certificate chain (%s)", |
| TLSCAChainPath) |
| } |
| rootCAs := x509.NewCertPool() |
| rootCAs.AppendCertsFromPEM(tlsCAChain) |
| |
| expectedTLSConfig := &tls.Config{ |
| RootCAs: rootCAs, |
| ServerName: configTLSConfig.ServerName, |
| InsecureSkipVerify: configTLSConfig.InsecureSkipVerify} |
| |
| tlsConfig, err := NewTLSConfig(&configTLSConfig) |
| if err != nil { |
| t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err) |
| } |
| |
| clientCertificate, err := tls.LoadX509KeyPair(ClientCertificatePath, ClientKeyNoPassPath) |
| if err != nil { |
| t.Fatalf("Can't load the client key pair ('%s' and '%s'). Reason: %s", |
| ClientCertificatePath, ClientKeyNoPassPath, err) |
| } |
| cert, err := tlsConfig.GetClientCertificate(nil) |
| if err != nil { |
| t.Fatalf("unexpected error returned by tlsConfig.GetClientCertificate(): %s", err) |
| } |
| if !reflect.DeepEqual(cert, &clientCertificate) { |
| t.Fatalf("Unexpected client certificate result: \n\n%+v\n expected\n\n%+v", cert, clientCertificate) |
| } |
| |
| // tlsConfig.rootCAs.LazyCerts contains functions getCert() in go 1.16, which are |
| // never equal. Compare the Subjects instead. |
| if !reflect.DeepEqual(tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects()) { |
| t.Fatalf("Unexpected RootCAs result: \n\n%+v\n expected\n\n%+v", tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects()) |
| } |
| tlsConfig.RootCAs = nil |
| expectedTLSConfig.RootCAs = nil |
| |
| // Non-nil functions are never equal. |
| tlsConfig.GetClientCertificate = nil |
| |
| if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) { |
| t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig) |
| } |
| } |
| |
| func TestTLSConfigEmpty(t *testing.T) { |
| configTLSConfig := TLSConfig{ |
| InsecureSkipVerify: true, |
| } |
| |
| expectedTLSConfig := &tls.Config{ |
| InsecureSkipVerify: configTLSConfig.InsecureSkipVerify, |
| } |
| |
| tlsConfig, err := NewTLSConfig(&configTLSConfig) |
| if err != nil { |
| t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err) |
| } |
| |
| if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) { |
| t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig) |
| } |
| } |
| |
| func TestTLSConfigInvalidCA(t *testing.T) { |
| var invalidTLSConfig = []struct { |
| configTLSConfig TLSConfig |
| errorMessage string |
| }{ |
| { |
| configTLSConfig: TLSConfig{ |
| CAFile: MissingCA, |
| CertFile: "", |
| KeyFile: "", |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| errorMessage: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), |
| }, { |
| configTLSConfig: TLSConfig{ |
| CAFile: "", |
| CertFile: MissingCert, |
| KeyFile: ClientKeyNoPassPath, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath), |
| }, { |
| configTLSConfig: TLSConfig{ |
| CAFile: "", |
| CertFile: ClientCertificatePath, |
| KeyFile: MissingKey, |
| ServerName: "", |
| InsecureSkipVerify: false}, |
| errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey), |
| }, |
| } |
| |
| for _, anInvalididTLSConfig := range invalidTLSConfig { |
| tlsConfig, err := NewTLSConfig(&anInvalididTLSConfig.configTLSConfig) |
| if tlsConfig != nil && err == nil { |
| t.Errorf("The TLS Config could be created even with this %+v", anInvalididTLSConfig.configTLSConfig) |
| continue |
| } |
| if !strings.Contains(err.Error(), anInvalididTLSConfig.errorMessage) { |
| t.Errorf("The expected error should contain %s, but got %s", anInvalididTLSConfig.errorMessage, err) |
| } |
| } |
| } |
| |
| func TestBasicAuthNoPassword(t *testing.T) { |
| cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-password.yaml") |
| if err != nil { |
| t.Fatalf("Error loading HTTP client config: %v", err) |
| } |
| client, err := NewClientFromConfig(*cfg, "test") |
| if err != nil { |
| t.Fatalf("Error creating HTTP Client: %v", err) |
| } |
| |
| rt, ok := client.Transport.(*basicAuthRoundTripper) |
| if !ok { |
| t.Fatalf("Error casting to basic auth transport, %v", client.Transport) |
| } |
| |
| if rt.username != "user" { |
| t.Errorf("Bad HTTP client username: %s", rt.username) |
| } |
| if string(rt.password) != "" { |
| t.Errorf("Expected empty HTTP client password: %s", rt.password) |
| } |
| if string(rt.passwordFile) != "" { |
| t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) |
| } |
| } |
| |
| func TestBasicAuthNoUsername(t *testing.T) { |
| cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-username.yaml") |
| if err != nil { |
| t.Fatalf("Error loading HTTP client config: %v", err) |
| } |
| client, err := NewClientFromConfig(*cfg, "test") |
| if err != nil { |
| t.Fatalf("Error creating HTTP Client: %v", err) |
| } |
| |
| rt, ok := client.Transport.(*basicAuthRoundTripper) |
| if !ok { |
| t.Fatalf("Error casting to basic auth transport, %v", client.Transport) |
| } |
| |
| if rt.username != "" { |
| t.Errorf("Got unexpected username: %s", rt.username) |
| } |
| if string(rt.password) != "secret" { |
| t.Errorf("Unexpected HTTP client password: %s", string(rt.password)) |
| } |
| if string(rt.passwordFile) != "" { |
| t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) |
| } |
| } |
| |
| func TestBasicAuthPasswordFile(t *testing.T) { |
| cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.good.yaml") |
| if err != nil { |
| t.Fatalf("Error loading HTTP client config: %v", err) |
| } |
| client, err := NewClientFromConfig(*cfg, "test") |
| if err != nil { |
| t.Fatalf("Error creating HTTP Client: %v", err) |
| } |
| |
| rt, ok := client.Transport.(*basicAuthRoundTripper) |
| if !ok { |
| t.Fatalf("Error casting to basic auth transport, %v", client.Transport) |
| } |
| |
| if rt.username != "user" { |
| t.Errorf("Bad HTTP client username: %s", rt.username) |
| } |
| if string(rt.password) != "" { |
| t.Errorf("Bad HTTP client password: %s", rt.password) |
| } |
| if string(rt.passwordFile) != "testdata/basic-auth-password" { |
| t.Errorf("Bad HTTP client passwordFile: %s", rt.passwordFile) |
| } |
| } |
| |
| func getCertificateBlobs(t *testing.T) map[string][]byte { |
| files := []string{ |
| TLSCAChainPath, |
| ClientCertificatePath, |
| ClientKeyNoPassPath, |
| ServerCertificatePath, |
| ServerKeyPath, |
| WrongClientCertPath, |
| WrongClientKeyPath, |
| EmptyFile, |
| } |
| bs := make(map[string][]byte, len(files)+1) |
| for _, f := range files { |
| b, err := ioutil.ReadFile(f) |
| if err != nil { |
| t.Fatal(err) |
| } |
| bs[f] = b |
| } |
| |
| return bs |
| } |
| |
// writeCertificate writes the blob stored under key src in bs to the file dst
// with mode 0664, replacing any existing contents. It panics if src is not in
// bs or the write fails. Panicking instead of taking a *testing.T lets it be
// called from helper goroutines, as TestTLSRoundTripperRaces does.
func writeCertificate(bs map[string][]byte, src string, dst string) {
	if data, found := bs[src]; found {
		if err := ioutil.WriteFile(dst, data, 0664); err != nil {
			panic(err)
		}
		return
	}
	panic(fmt.Sprintf("Couldn't find %q in bs", src))
}
| |
| func TestTLSRoundTripper(t *testing.T) { |
| bs := getCertificateBlobs(t) |
| |
| tmpDir, err := ioutil.TempDir("", "tlsroundtripper") |
| if err != nil { |
| t.Fatal("Failed to create tmp dir", err) |
| } |
| defer os.RemoveAll(tmpDir) |
| |
| ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key") |
| |
| handler := func(w http.ResponseWriter, r *http.Request) { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| testServer, err := newTestServer(handler) |
| if err != nil { |
| t.Fatal(err.Error()) |
| } |
| defer testServer.Close() |
| |
| testCases := []struct { |
| ca string |
| cert string |
| key string |
| |
| errMsg string |
| }{ |
| { |
| // Valid certs. |
| ca: TLSCAChainPath, |
| cert: ClientCertificatePath, |
| key: ClientKeyNoPassPath, |
| }, |
| { |
| // CA not matching. |
| ca: ClientCertificatePath, |
| cert: ClientCertificatePath, |
| key: ClientKeyNoPassPath, |
| |
| errMsg: "certificate signed by unknown authority", |
| }, |
| { |
| // Invalid client cert+key. |
| ca: TLSCAChainPath, |
| cert: WrongClientCertPath, |
| key: WrongClientKeyPath, |
| |
| errMsg: "remote error: tls", |
| }, |
| { |
| // CA file empty |
| ca: EmptyFile, |
| cert: ClientCertificatePath, |
| key: ClientKeyNoPassPath, |
| |
| errMsg: "unable to use specified CA cert", |
| }, |
| { |
| // cert file empty |
| ca: TLSCAChainPath, |
| cert: EmptyFile, |
| key: ClientKeyNoPassPath, |
| |
| errMsg: "failed to find any PEM data in certificate input", |
| }, |
| { |
| // key file empty |
| ca: TLSCAChainPath, |
| cert: ClientCertificatePath, |
| key: EmptyFile, |
| |
| errMsg: "failed to find any PEM data in key input", |
| }, |
| { |
| // Valid certs again. |
| ca: TLSCAChainPath, |
| cert: ClientCertificatePath, |
| key: ClientKeyNoPassPath, |
| }, |
| } |
| |
| cfg := HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: ca, |
| CertFile: cert, |
| KeyFile: key, |
| InsecureSkipVerify: false}, |
| } |
| |
| var c *http.Client |
| for i, tc := range testCases { |
| tc := tc |
| t.Run(strconv.Itoa(i), func(t *testing.T) { |
| writeCertificate(bs, tc.ca, ca) |
| writeCertificate(bs, tc.cert, cert) |
| writeCertificate(bs, tc.key, key) |
| if c == nil { |
| c, err = NewClientFromConfig(cfg, "test") |
| if err != nil { |
| t.Fatalf("Error creating HTTP Client: %v", err) |
| } |
| } |
| |
| req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) |
| if err != nil { |
| t.Fatalf("Error creating HTTP request: %v", err) |
| } |
| r, err := c.Do(req) |
| if len(tc.errMsg) > 0 { |
| if err == nil { |
| r.Body.Close() |
| t.Fatalf("Could connect to the test server.") |
| } |
| if !strings.Contains(err.Error(), tc.errMsg) { |
| t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err) |
| } |
| return |
| } |
| |
| if err != nil { |
| t.Fatalf("Can't connect to the test server") |
| } |
| |
| b, err := ioutil.ReadAll(r.Body) |
| r.Body.Close() |
| if err != nil { |
| t.Errorf("Can't read the server response body") |
| } |
| |
| got := strings.TrimSpace(string(b)) |
| if ExpectedMessage != got { |
| t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got) |
| } |
| }) |
| } |
| } |
| |
| func TestTLSRoundTripperRaces(t *testing.T) { |
| bs := getCertificateBlobs(t) |
| |
| tmpDir, err := ioutil.TempDir("", "tlsroundtripper") |
| if err != nil { |
| t.Fatal("Failed to create tmp dir", err) |
| } |
| defer os.RemoveAll(tmpDir) |
| |
| ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key") |
| |
| handler := func(w http.ResponseWriter, r *http.Request) { |
| fmt.Fprint(w, ExpectedMessage) |
| } |
| testServer, err := newTestServer(handler) |
| if err != nil { |
| t.Fatal(err.Error()) |
| } |
| defer testServer.Close() |
| |
| cfg := HTTPClientConfig{ |
| TLSConfig: TLSConfig{ |
| CAFile: ca, |
| CertFile: cert, |
| KeyFile: key, |
| InsecureSkipVerify: false}, |
| } |
| |
| var c *http.Client |
| writeCertificate(bs, TLSCAChainPath, ca) |
| writeCertificate(bs, ClientCertificatePath, cert) |
| writeCertificate(bs, ClientKeyNoPassPath, key) |
| c, err = NewClientFromConfig(cfg, "test") |
| if err != nil { |
| t.Fatalf("Error creating HTTP Client: %v", err) |
| } |
| |
| var wg sync.WaitGroup |
| ch := make(chan struct{}) |
| var total, ok int64 |
| // Spawn 10 Go routines polling the server concurrently. |
| for i := 0; i < 10; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| for { |
| select { |
| case <-ch: |
| return |
| default: |
| atomic.AddInt64(&total, 1) |
| r, err := c.Get(testServer.URL) |
| if err == nil { |
| r.Body.Close() |
| atomic.AddInt64(&ok, 1) |
| } |
| } |
| } |
| }() |
| } |
| |
| // Change the CA file every 10ms for 1 second. |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| i := 0 |
| for { |
| tick := time.NewTicker(10 * time.Millisecond) |
| <-tick.C |
| if i%2 == 0 { |
| writeCertificate(bs, ClientCertificatePath, ca) |
| } else { |
| writeCertificate(bs, TLSCAChainPath, ca) |
| } |
| i++ |
| if i > 100 { |
| close(ch) |
| return |
| } |
| } |
| }() |
| |
| wg.Wait() |
| if ok == total { |
| t.Fatalf("Expecting some requests to fail but got %d/%d successful requests", ok, total) |
| } |
| } |
| |
| func TestHideHTTPClientConfigSecrets(t *testing.T) { |
| c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") |
| if err != nil { |
| t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) |
| } |
| |
| // String method must not reveal authentication credentials. |
| s := c.String() |
| if strings.Contains(s, "mysecret") { |
| t.Fatal("http client config's String method reveals authentication credentials.") |
| } |
| } |
| |
| func TestDefaultFollowRedirect(t *testing.T) { |
| cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") |
| if err != nil { |
| t.Errorf("Error loading HTTP client config: %v", err) |
| } |
| if !cfg.FollowRedirects { |
| t.Errorf("follow_redirects should be true") |
| } |
| } |
| |
| func TestValidateHTTPConfig(t *testing.T) { |
| cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") |
| if err != nil { |
| t.Errorf("Error loading HTTP client config: %v", err) |
| } |
| err = cfg.Validate() |
| if err != nil { |
| t.Fatalf("Error validating %s: %s", "testdata/http.conf.good.yml", err) |
| } |
| } |
| |
| func TestInvalidHTTPConfigs(t *testing.T) { |
| for _, ee := range invalidHTTPClientConfigs { |
| _, _, err := LoadHTTPConfigFile(ee.httpClientConfigFile) |
| if err == nil { |
| t.Error("Expected error with config but got none") |
| continue |
| } |
| if !strings.Contains(err.Error(), ee.errMsg) { |
| t.Errorf("Expected error for invalid HTTP client configuration to contain %q but got: %s", ee.errMsg, err) |
| } |
| } |
| } |
| |
| // LoadHTTPConfig parses the YAML input s into a HTTPClientConfig. |
| func LoadHTTPConfig(s string) (*HTTPClientConfig, error) { |
| cfg := &HTTPClientConfig{} |
| err := yaml.UnmarshalStrict([]byte(s), cfg) |
| if err != nil { |
| return nil, err |
| } |
| return cfg, nil |
| } |
| |
| // LoadHTTPConfigFile parses the given YAML file into a HTTPClientConfig. |
| func LoadHTTPConfigFile(filename string) (*HTTPClientConfig, []byte, error) { |
| content, err := ioutil.ReadFile(filename) |
| if err != nil { |
| return nil, nil, err |
| } |
| cfg, err := LoadHTTPConfig(string(content)) |
| if err != nil { |
| return nil, nil, err |
| } |
| return cfg, content, nil |
| } |
| |
// roundTrip is a stub http.RoundTripper that ignores the request it is given
// and always answers with the stored theResponse and theError.
type roundTrip struct {
	theResponse *http.Response
	theError    error
}

// RoundTrip returns the canned response and error without inspecting the request.
func (rt *roundTrip) RoundTrip(_ *http.Request) (resp *http.Response, err error) {
	resp, err = rt.theResponse, rt.theError
	return resp, err
}
| |
| type roundTripCheckRequest struct { |
| checkRequest func(*http.Request) |
| roundTrip |
| } |
| |
| func (rt *roundTripCheckRequest) RoundTrip(r *http.Request) (*http.Response, error) { |
| rt.checkRequest(r) |
| return rt.theResponse, rt.theError |
| } |
| |
| // NewRoundTripCheckRequest creates a new instance of a type that implements http.RoundTripper, |
| // which before returning theResponse and theError, executes checkRequest against a http.Request. |
| func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *http.Response, theError error) http.RoundTripper { |
| return &roundTripCheckRequest{ |
| checkRequest: checkRequest, |
| roundTrip: roundTrip{ |
| theResponse: theResponse, |
| theError: theError}} |
| } |
| |
// testServerResponse is the JSON body written by the fake OAuth2 token
// endpoints in TestOAuth2 and TestOAuth2WithFile.
type testServerResponse struct {
	AccessToken string `json:"access_token"` // token handed back to the client
	TokenType   string `json:"token_type"`   // the tests set this to "Bearer"
}
| |
| func TestOAuth2(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| res, _ := json.Marshal(testServerResponse{ |
| AccessToken: "12345", |
| TokenType: "Bearer", |
| }) |
| w.Header().Add("Content-Type", "application/json") |
| _, _ = w.Write(res) |
| })) |
| defer ts.Close() |
| |
| var yamlConfig = fmt.Sprintf(` |
| client_id: 1 |
| client_secret: 2 |
| scopes: |
| - A |
| - B |
| token_url: %s |
| endpoint_params: |
| hi: hello |
| `, ts.URL) |
| expectedConfig := OAuth2{ |
| ClientID: "1", |
| ClientSecret: "2", |
| Scopes: []string{"A", "B"}, |
| EndpointParams: map[string]string{"hi": "hello"}, |
| TokenURL: ts.URL, |
| } |
| |
| var unmarshalledConfig OAuth2 |
| err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) |
| if err != nil { |
| t.Fatalf("Expected no error unmarshalling yaml, got %v", err) |
| } |
| if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { |
| t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) |
| } |
| |
| rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) |
| |
| client := http.Client{ |
| Transport: rt, |
| } |
| resp, _ := client.Get(ts.URL) |
| |
| authorization := resp.Request.Header.Get("Authorization") |
| if authorization != "Bearer 12345" { |
| t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) |
| } |
| } |
| |
| func TestOAuth2WithFile(t *testing.T) { |
| var expectedAuth *string |
| var previousAuth string |
| tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| auth := r.Header.Get("Authorization") |
| if auth != *expectedAuth { |
| t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth) |
| } |
| if auth == previousAuth { |
| t.Fatal("token endpoint called twice") |
| } |
| previousAuth = auth |
| res, _ := json.Marshal(testServerResponse{ |
| AccessToken: "12345", |
| TokenType: "Bearer", |
| }) |
| w.Header().Add("Content-Type", "application/json") |
| _, _ = w.Write(res) |
| })) |
| defer tokenTS.Close() |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| auth := r.Header.Get("Authorization") |
| if auth != "Bearer 12345" { |
| t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth) |
| } |
| fmt.Fprintln(w, "Hello, client") |
| })) |
| defer ts.Close() |
| |
| secretFile, err := ioutil.TempFile("", "oauth2_secret") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer os.Remove(secretFile.Name()) |
| |
| var yamlConfig = fmt.Sprintf(` |
| client_id: 1 |
| client_secret_file: %s |
| scopes: |
| - A |
| - B |
| token_url: %s |
| endpoint_params: |
| hi: hello |
| `, secretFile.Name(), tokenTS.URL) |
| expectedConfig := OAuth2{ |
| ClientID: "1", |
| ClientSecretFile: secretFile.Name(), |
| Scopes: []string{"A", "B"}, |
| EndpointParams: map[string]string{"hi": "hello"}, |
| TokenURL: tokenTS.URL, |
| } |
| |
| var unmarshalledConfig OAuth2 |
| err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) |
| if err != nil { |
| t.Fatalf("Expected no error unmarshalling yaml, got %v", err) |
| } |
| if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { |
| t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) |
| } |
| |
| rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) |
| |
| client := http.Client{ |
| Transport: rt, |
| } |
| |
| tk := "Basic MToxMjM0NTY=" |
| expectedAuth = &tk |
| if _, err := secretFile.Write([]byte("123456")); err != nil { |
| t.Fatal(err) |
| } |
| resp, err := client.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| authorization := resp.Request.Header.Get("Authorization") |
| if authorization != "Bearer 12345" { |
| t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) |
| } |
| |
| // Making a second request with the same file content should not re-call the token API. |
| resp, err = client.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| tk = "Basic MToxMjM0NTY3" |
| expectedAuth = &tk |
| if _, err := secretFile.Write([]byte("7")); err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Making a second request with the same file content should not re-call the token API. |
| _, err = client.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| authorization = resp.Request.Header.Get("Authorization") |
| if authorization != "Bearer 12345" { |
| t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) |
| } |
| } |
| |
| func TestMarshalURL(t *testing.T) { |
| urlp, err := url.Parse("http://example.com/") |
| if err != nil { |
| t.Fatal(err) |
| } |
| u := &URL{urlp} |
| |
| c, err := json.Marshal(u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(c) != "\"http://example.com/\"" { |
| t.Fatalf("URL not properly marshaled in JSON got '%s'", string(c)) |
| } |
| |
| c, err = yaml.Marshal(u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(c) != "http://example.com/\n" { |
| t.Fatalf("URL not properly marshaled in YAML got '%s'", string(c)) |
| } |
| } |
| |
| func TestMarshalURLWrapperWithNilValue(t *testing.T) { |
| u := &URL{} |
| |
| c, err := json.Marshal(u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(c) != "null" { |
| t.Fatalf("URL with nil value not properly marshaled into JSON, got %q", c) |
| } |
| |
| c, err = yaml.Marshal(u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(c) != "null\n" { |
| t.Fatalf("URL with nil value not properly marshaled into JSON, got %q", c) |
| } |
| } |
| |
| func TestUnmarshalNullURL(t *testing.T) { |
| b := []byte(`null`) |
| |
| { |
| var u URL |
| err := json.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !isEmptyNonNilURL(u.URL) { |
| t.Fatalf("`null` literal not properly unmarshaled from JSON as URL, got %#v", u.URL) |
| } |
| } |
| |
| { |
| var u URL |
| err := yaml.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if u.URL != nil { // UnmarshalYAML is not called when parsing null literal. |
| t.Fatalf("`null` literal not properly unmarshaled from YAML as URL, got %#v", u.URL) |
| } |
| } |
| } |
| |
| func TestUnmarshalEmptyURL(t *testing.T) { |
| b := []byte(`""`) |
| |
| { |
| var u URL |
| err := json.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !isEmptyNonNilURL(u.URL) { |
| t.Fatalf("empty string not properly unmarshaled from JSON as URL, got %#v", u.URL) |
| } |
| } |
| |
| { |
| var u URL |
| err := yaml.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !isEmptyNonNilURL(u.URL) { |
| t.Fatalf("empty string not properly unmarshaled from YAML as URL, got %#v", u.URL) |
| } |
| } |
| } |
| |
// isEmptyNonNilURL reports whether u is non-nil and points to a zero-value
// url.URL, i.e. *u == url.URL{}.
func isEmptyNonNilURL(u *url.URL) bool {
	if u == nil {
		return false
	}
	var zero url.URL
	return *u == zero
}
| |
| func TestUnmarshalURL(t *testing.T) { |
| b := []byte(`"http://example.com/a b"`) |
| var u URL |
| |
| err := json.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if u.String() != "http://example.com/a%20b" { |
| t.Fatalf("URL not properly unmarshaled in JSON, got '%s'", u.String()) |
| } |
| |
| err = yaml.Unmarshal(b, &u) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if u.String() != "http://example.com/a%20b" { |
| t.Fatalf("URL not properly unmarshaled in YAML, got '%s'", u.String()) |
| } |
| } |