blob: a8d6e40a5770f2430467a0c0a27a1b41120a88ef [file] [log] [blame]
// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package lhttp
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"go.chromium.org/luci/common/retry"
. "github.com/smartystreets/goconvey/convey"
)
// httpReqGen returns a RequestGen that builds a brand-new *http.Request with
// the given method and URL every time it is invoked.
//
// If body is non-nil, each generated request gets its own fresh reader over
// body, so the identical payload is sent again on every retry. A nil body
// yields a request without a body.
func httpReqGen(method, url string, body []byte) RequestGen {
	return func() (*http.Request, error) {
		// Create a new reader per request: a reader drained by an earlier
		// attempt cannot be replayed.
		var bodyReader io.Reader
		if body != nil {
			bodyReader = bytes.NewReader(body)
		}
		// Honor the caller-supplied method instead of always issuing GET.
		return http.NewRequest(method, url, bodyReader)
	}
}
// TestNewRequestGET checks that a GET request whose first attempt receives an
// HTTP 500 is retried, that the server sees an empty body each time, and that
// the response handler runs exactly once, for the successful response.
func TestNewRequestGET(t *testing.T) {
	Convey(`HTTP GET requests should be handled correctly.`, t, func(c C) {
		ctx := context.Background()

		// The server fails its first call with HTTP 500 and succeeds afterwards.
		hits := 0
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hits++
			payload, err := io.ReadAll(r.Body)
			c.So(err, ShouldBeNil)
			c.So(payload, ShouldResemble, []byte{})
			switch hits {
			case 1:
				w.WriteHeader(500)
			default:
				fmt.Fprintf(w, "Hello, client\n")
			}
		}))
		defer srv.Close()

		handled := 0
		onResponse := func(resp *http.Response) error {
			handled++
			payload, err := io.ReadAll(resp.Body)
			So(err, ShouldBeNil)
			So(string(payload), ShouldResemble, "Hello, client\n")
			So(resp.Body.Close(), ShouldBeNil)
			return nil
		}
		doRequest := NewRequest(ctx, http.DefaultClient, fast, httpReqGen("GET", srv.URL, nil), onResponse, nil)

		status, err := doRequest()
		So(err, ShouldBeNil)
		So(status, ShouldResemble, 200)
		So(hits, ShouldResemble, 2)
		So(handled, ShouldResemble, 1)
	})
}
// TestNewRequestPOST checks that a request built by httpReqGen("POST", ...)
// with a body is retried after an HTTP 500, that the identical body reaches the
// server on both attempts, and that the response handler runs exactly once.
func TestNewRequestPOST(t *testing.T) {
	Convey(`HTTP POST requests should be handled correctly.`, t, func(c C) {
		ctx := context.Background()

		// The server rejects its first call with HTTP 500 and accepts the rest;
		// every call must carry the full payload.
		hits := 0
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hits++
			received, err := io.ReadAll(r.Body)
			c.So(err, ShouldBeNil)
			c.So(string(received), ShouldResemble, "foo bar")
			if hits > 1 {
				fmt.Fprintf(w, "Hello, client\n")
				return
			}
			w.WriteHeader(500)
		}))
		defer srv.Close()

		handled := 0
		onResponse := func(resp *http.Response) error {
			handled++
			received, err := io.ReadAll(resp.Body)
			So(err, ShouldBeNil)
			So(string(received), ShouldResemble, "Hello, client\n")
			So(resp.Body.Close(), ShouldBeNil)
			return nil
		}
		doRequest := NewRequest(ctx, http.DefaultClient, fast, httpReqGen("POST", srv.URL, []byte("foo bar")), onResponse, nil)

		status, err := doRequest()
		So(err, ShouldBeNil)
		So(status, ShouldResemble, 200)
		So(hits, ShouldResemble, 2)
		So(handled, ShouldResemble, 1)
	})
}
// TestNewRequestGETFail checks that a GET request to a server that always
// answers HTTP 500 gives up once the fast retry budget is exhausted, reports
// status 500, sends every attempt to the server, and never invokes the
// response handler.
func TestNewRequestGETFail(t *testing.T) {
	Convey(`HTTP GET requests should handle failure successfully.`, t, func() {
		ctx := context.Background()
		serverCalls := 0
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			serverCalls++
			w.WriteHeader(500)
		}))
		defer ts.Close()
		httpReq := httpReqGen("GET", ts.URL, nil)
		clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error {
			// Every attempt fails, so the handler must never be reached.
			t.Fail()
			return nil
		}, nil)
		status, err := clientReq()
		// Assert non-nil first: calling err.Error() on a nil error would panic
		// instead of producing a readable assertion failure.
		So(err, ShouldNotBeNil)
		So(err.Error(), ShouldEqual, "gave up after 4 attempts: http request failed: Internal Server Error (HTTP 500)")
		So(status, ShouldEqual, 500)
		// Each of the 4 attempts reported in the error must have hit the server.
		So(serverCalls, ShouldEqual, 4)
	})
}
func TestNewRequestDefaultFactory(t *testing.T) {
// Test that the default factory (rFn == nil) only retries for transient
// HTTP errors.
testCases := []struct {
statusCode int // The status code to return (the first 2 times).
path string // Request path, if any.
wantErr bool // Whether we want NewRequest to return an error.
wantCalls int // The total number of HTTP requests expected.
}{
// 200, passes immediately.
{statusCode: 200, wantErr: false, wantCalls: 1},
// Transient HTTP error codes that will retry.
{statusCode: 408, wantErr: false, wantCalls: 3},
{statusCode: 500, wantErr: false, wantCalls: 3},
{statusCode: 503, wantErr: false, wantCalls: 3},
// Immediate failure codes.
{statusCode: 403, wantErr: true, wantCalls: 1},
{statusCode: 404, wantErr: true, wantCalls: 1},
}
ctx := context.Background()
for _, tc := range testCases {
tc := tc
t.Run(fmt.Sprintf("Status code %d, path %q", tc.statusCode, tc.path), func(t *testing.T) {
t.Parallel()
serverCalls := 0
ts := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
serverCalls++
if serverCalls <= 2 {
w.WriteHeader(tc.statusCode)
}
fmt.Fprintf(w, "Hello World!\n")
}))
defer ts.Close()
httpReq := httpReqGen("GET", ts.URL+tc.path, nil)
req := NewRequest(ctx, http.DefaultClient, nil, httpReq, func(resp *http.Response) error {
return resp.Body.Close()
}, nil)
_, err := req()
if err == nil && tc.wantErr {
t.Error("req returned nil error, wanted an error")
} else if err != nil && !tc.wantErr {
t.Errorf("req returned err %v, wanted nil", err)
}
if got, want := serverCalls, tc.wantCalls; got != want {
t.Errorf("total server calls; got %d, want %d", got, want)
}
})
}
}
// TestNewRequestClosesBody checks three things:
//   - the bodies of all responses end up closed, including the HTTP 500
//     responses that trigger a retry;
//   - only the final response is handed to the response handler;
//   - the handled response is the last one the transport returned.
func TestNewRequestClosesBody(t *testing.T) {
	ctx := context.Background()
	serverCalls := 0
	// Return a 500 for the first 2 requests; later requests write no explicit
	// status and therefore get 200.
	ts := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer r.Body.Close()
			serverCalls++
			if serverCalls <= 2 {
				w.WriteHeader(500)
			}
			fmt.Fprintf(w, "Hello World!\n")
		}))
	defer ts.Close()

	// Route requests through a transport that records every response and wraps
	// each body so that closing it can be observed afterwards.
	rt := &trackingRoundTripper{RoundTripper: http.DefaultTransport}
	hc := &http.Client{Transport: rt}
	httpReq := httpReqGen("GET", ts.URL, nil)
	clientCalls := 0
	var lastResp *http.Response
	req := NewRequest(ctx, hc, fast, httpReq, func(resp *http.Response) error {
		clientCalls++
		lastResp = resp
		return resp.Body.Close()
	}, nil)
	status, err := req()
	if err != nil {
		t.Fatalf("req returned err %v, want nil", err)
	}
	if got, want := status, http.StatusOK; got != want {
		t.Errorf("req returned status %d, want %d", got, want)
	}
	// We expect only one client call, but three requests through to the server.
	if got, want := clientCalls, 1; got != want {
		t.Errorf("handler callback invoked %d times, want %d", got, want)
	}
	// Fatal rather than Error: the checks below index rt.Responses[2] and would
	// panic with index-out-of-range if fewer responses were recorded.
	if got, want := len(rt.Responses), 3; got != want {
		t.Fatalf("len(Responses) = %d, want %d", got, want)
	}
	// Check that the last response is the one we handled, and that all the
	// bodies were closed.
	if got, want := lastResp, rt.Responses[2]; got != want {
		t.Errorf("Last Response did not match Response in handler callback.\nGot: %v\nWant: %v", got, want)
	}
	for i, resp := range rt.Responses {
		if resp == nil {
			// No response means no body to inspect; report instead of panicking.
			t.Errorf("Responses[%d] is nil", i)
			continue
		}
		rc, ok := resp.Body.(*trackingReadCloser)
		if !ok {
			t.Errorf("Responses[%d].Body is %T, want *trackingReadCloser", i, resp.Body)
			continue
		}
		if !rc.Closed {
			t.Errorf("Responses[%d].Body was not closed", i)
		}
	}
}
// trackingRoundTripper decorates an http.RoundTripper.
//
// It appends every response returned by the wrapped RoundTripper to Responses,
// in the order they are recorded, including nil ones. Any non-nil response Body
// is replaced with a trackingReadCloser so tests can later check whether it was
// closed.
type trackingRoundTripper struct {
	http.RoundTripper // performs the actual round trips

	mu        sync.Mutex       // guards Responses
	Responses []*http.Response // every response seen so far
}

// RoundTrip delegates req to the wrapped RoundTripper. It wraps the response
// body for close tracking and records the response. It then returns the
// (possibly wrapped) response together with the unmodified error.
func (rt *trackingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := rt.RoundTripper.RoundTrip(req)
	if resp != nil && resp.Body != nil {
		resp.Body = &trackingReadCloser{ReadCloser: resp.Body}
	}

	rt.mu.Lock()
	rt.Responses = append(rt.Responses, resp)
	rt.mu.Unlock()

	return resp, err
}
// trackingReadCloser decorates an io.ReadCloser and remembers whether Close
// has been invoked. Reads pass straight through to the embedded ReadCloser.
type trackingReadCloser struct {
	io.ReadCloser      // the underlying stream
	Closed        bool // becomes true once Close has been called
}

// Close flags the reader as closed, then closes the embedded ReadCloser and
// returns whatever error it reports.
func (rc *trackingReadCloser) Close() error {
	rc.Closed = true
	return rc.ReadCloser.Close()
}
// fast is a retry factory for tests. Each call returns a new retry.Limited
// iterator that allows 3 retries. Only Retries is set; all other Limited
// fields keep their zero values.
func fast() retry.Iterator {
	limited := retry.Limited{Retries: 3}
	return &limited
}