blob: 8dea1dfa93b4741c373cfd056af337bd847b370c [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prpc
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/klauspost/compress/gzip"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/clock/testclock"
"go.chromium.org/luci/common/logging"
"go.chromium.org/luci/common/logging/memlogger"
"go.chromium.org/luci/common/retry"
. "github.com/smartystreets/goconvey/convey"
. "go.chromium.org/luci/common/testing/assertions"
)
// sayHello returns a fake "prpc.Greeter/SayHello" HTTP handler used by the
// client tests.
//
// The handler asserts (through c) that the request is a binary pRPC POST sent
// with the "prpc-test" user agent to one of the two accepted paths, decodes
// the HelloRequest body and replies with "Hello <name>". A few request names
// switch on special behavior:
//
//   - "TOO BIG": a bogus, huge Content-Length response header is set.
//   - "ACCEPT JSONPB": the request must accept "application/json" and the
//     reply is serialized with jsonpb instead of binary protobuf.
//   - "NOT FOUND": the reply carries codes.NotFound and HTTP 404.
//
// Requests sent to the "/python/prpc" prefixed path get " from python service"
// appended to the reply message.
func sayHello(c C) http.HandlerFunc {
	const (
		defaultPath = "/prpc/prpc.Greeter/SayHello"
		pythonPath  = "/python/prpc/prpc.Greeter/SayHello"
	)
	return func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path

		// Validate the shape of the incoming request.
		c.So(r.Method, ShouldEqual, "POST")
		c.So(path == defaultPath || path == pythonPath, ShouldBeTrue)
		c.So(r.Header.Get("Content-Type"), ShouldEqual, "application/prpc; encoding=binary")
		c.So(r.Header.Get("User-Agent"), ShouldEqual, "prpc-test")
		// The timeout header is optional; when present it must be exactly 10s
		// expressed in microseconds.
		if timeout := r.Header.Get(HeaderTimeout); timeout != "" {
			c.So(timeout, ShouldEqual, "10000000u")
		}

		// Decode the request message.
		rawReq, err := io.ReadAll(r.Body)
		c.So(err, ShouldBeNil)
		var req HelloRequest
		c.So(proto.Unmarshal(rawReq, &req), ShouldBeNil)

		hdr := w.Header()
		if req.Name == "TOO BIG" {
			hdr.Set("Content-Length", "999999999999")
		}
		hdr.Set("X-Lower-Case-Header", "CamelCaseValueStays")

		reply := HelloReply{Message: "Hello " + req.Name}
		if path == pythonPath {
			reply.Message += " from python service"
		}

		// Serialize the reply in the format the request is expected to accept.
		accept := r.Header.Get("Accept")
		var body []byte
		if req.Name == "ACCEPT JSONPB" {
			c.So(accept, ShouldEqual, "application/json")
			text, err := (&jsonpb.Marshaler{}).MarshalToString(&reply)
			c.So(err, ShouldBeNil)
			body = []byte(text)
		} else {
			c.So(accept, ShouldEqual, "application/prpc; encoding=binary")
			body, err = proto.Marshal(&reply)
			c.So(err, ShouldBeNil)
		}

		grpcCode, httpStatus := codes.OK, http.StatusOK
		if req.Name == "NOT FOUND" {
			grpcCode, httpStatus = codes.NotFound, http.StatusNotFound
		}

		// The response content type simply echoes the request's Accept header.
		hdr.Set("Content-Type", accept)
		hdr.Set(HeaderGRPCCode, strconv.Itoa(int(grpcCode)))
		w.WriteHeader(httpStatus)
		_, err = w.Write(body)
		c.So(err, ShouldBeNil)
	}
}
// doPanicHandler is an HTTP handler that panics unconditionally. Tests install
// it on servers that must never receive a request.
func doPanicHandler(w http.ResponseWriter, r *http.Request) {
	const reason = "test panic"
	panic(reason)
}
// transientErrors wraps then with a handler that fails the first count
// requests and delegates every later request to then.
//
// A failed request is answered with httpStatus and the body
// "Server misbehaved\n"; when grpcHeader is true the response also carries
// codes.Internal in the gRPC code header. The remaining-failures counter is
// not synchronized, so the handler is only suitable for sequential requests.
func transientErrors(count int, grpcHeader bool, httpStatus int, then http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Failure budget exhausted: behave like the wrapped handler.
		if count <= 0 {
			then.ServeHTTP(w, r)
			return
		}

		count--
		if grpcHeader {
			w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Internal)))
		}
		w.WriteHeader(httpStatus)
		fmt.Fprintln(w, "Server misbehaved")
	}
}
// advanceClockAndErr returns a handler that, on every request, moves the test
// clock tc forward by d and then replies with HTTP 500 and an empty body.
func advanceClockAndErr(tc testclock.TestClock, d time.Duration) http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		tc.Add(d)
		rw.WriteHeader(http.StatusInternalServerError)
	})
}
// shouldHaveMessagesLike is an assertion usable as
// So(log, shouldHaveMessagesLike, entries...).
//
// actual must be a *memlogger.MemLogger and every expected value must be a
// memlogger.LogEntry. The logger must hold exactly len(expected) messages;
// the i-th message must have the same level as the i-th expected entry and
// its text must contain the expected entry's Msg as a substring.
//
// Mismatches are reported through the nested So calls, so the function itself
// always returns the empty string.
func shouldHaveMessagesLike(actual any, expected ...any) string {
	entries := actual.(*memlogger.MemLogger).Messages()
	So(entries, ShouldHaveLength, len(expected))
	for idx := range entries {
		got := entries[idx]
		want := expected[idx].(memlogger.LogEntry)
		So(got.Level, ShouldEqual, want.Level)
		So(got.Msg, ShouldContainSubstring, want.Msg)
	}
	return ""
}
// TestClient exercises Client.Call against in-process httptest servers.
//
// Covered cases: the plain success path, a custom PathPrefix, JSONPB
// responses, outgoing metadata, gzip compression, overall and per-RPC
// deadlines, response size limits, logging of expected vs unexpected error
// codes, retries on HTTP 5xx responses, propagation of the gRPC code header
// and the MaxConcurrentRequests limit.
func TestClient(t *testing.T) {
t.Parallel()
// setUp starts an httptest server that serves h and returns a Client pointed
// at it: insecure (plain HTTP), "prpc-test" user agent, and a retry iterator
// limited to 3 retries with zero delay. The caller must close the server.
setUp := func(h http.HandlerFunc) (*Client, *httptest.Server) {
server := httptest.NewServer(h)
client := &Client{
Host: strings.TrimPrefix(server.URL, "http://"),
Options: &Options{
Retry: func() retry.Iterator {
return &retry.Limited{
Retries: 3,
Delay: 0,
}
},
Insecure: true,
UserAgent: "prpc-test",
},
}
return client, server
}
Convey("Client", t, func() {
// These unit tests use real HTTP connections to localhost. Since go 1.7
// 'net/http' library uses the context deadline to derive the connection
// timeout: it grabs the deadline (as time.Time) from the context and
// compares it to the current time. So we can't put arbitrary mocked time
// into the testclock (as it ends up in the context deadline passed to
// 'net/http'). We either have to use real clock in the unit tests, or
// "freeze" the time at the real "now" value.
ctx, tc := testclock.UseTime(context.Background(), time.Now().Local())
ctx = memlogger.Use(ctx)
log := logging.Get(ctx).(*memlogger.MemLogger)
// expectedCallLogEntry builds the debug-level log entry that the cases below
// expect for every attempt of the SayHello RPC made through c. Only the
// level and a substring of the message are compared (see
// shouldHaveMessagesLike).
expectedCallLogEntry := func(c *Client) memlogger.LogEntry {
return memlogger.LogEntry{
Level: logging.Debug,
Msg: fmt.Sprintf("RPC %s/prpc.Greeter.SayHello", c.Host),
}
}
// NOTE(review): several cases below mutate req.Name; this presumably relies
// on goconvey re-running this enclosing block for each leaf case so that
// every case starts from a fresh req/res — confirm.
req := &HelloRequest{Name: "John"}
res := &HelloReply{}
Convey("Call", func() {
Convey("Works", func(c C) {
client, server := setUp(sayHello(c))
defer server.Close()
var hd metadata.MD
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "Hello John")
// The server sets "X-Lower-Case-Header"; the received metadata key is
// expected in lower case while the value keeps its original casing.
So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"})
So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
})
Convey("Works with PathPrefix", func(c C) {
client, server := setUp(sayHello(c))
defer server.Close()
client.PathPrefix = "/python/prpc"
var hd metadata.MD
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "Hello John from python service")
})
Convey("Works with response in JSONPB", func(c C) {
// This name makes sayHello require "Accept: application/json" and reply
// with a jsonpb-encoded message.
req.Name = "ACCEPT JSONPB"
client, server := setUp(sayHello(c))
client.Options.AcceptContentSubtype = "json"
defer server.Close()
var hd metadata.MD
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "Hello ACCEPT JSONPB")
So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"})
So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
})
Convey("With outgoing metadata", func(c C) {
// Capture the request headers seen by the server before delegating to the
// regular greeter handler.
var receivedHeader http.Header
greeter := sayHello(c)
client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
receivedHeader = r.Header
greeter(w, r)
})
defer server.Close()
// "data-bin" carries raw bytes; the assertion below expects it to arrive
// as the base64 string "AAECAw==".
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(
"key", "value 1",
"key", "value 2",
"data-bin", string([]byte{0, 1, 2, 3}),
))
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldBeNil)
So(receivedHeader["Key"], ShouldResemble, []string{"value 1", "value 2"})
So(receivedHeader["Data-Bin"], ShouldResemble, []string{"AAECAw=="})
})
Convey("Works with compression", func(c C) {
// A 1 KiB name; the handler below asserts that the request body arrives
// gzip-compressed.
req := &HelloRequest{Name: strings.Repeat("A", 1024)}
client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
// Parse request.
c.So(r.Header.Get("Accept-Encoding"), ShouldEqual, "gzip")
c.So(r.Header.Get("Content-Encoding"), ShouldEqual, "gzip")
gz, err := gzip.NewReader(r.Body)
c.So(err, ShouldBeNil)
defer gz.Close()
reqBody, err := io.ReadAll(gz)
c.So(err, ShouldBeNil)
var actualReq HelloRequest
err = proto.Unmarshal(reqBody, &actualReq)
c.So(err, ShouldBeNil)
c.So(&actualReq, ShouldResembleProto, req)
// Write response.
resBytes, err := proto.Marshal(&HelloReply{Message: "compressed response"})
c.So(err, ShouldBeNil)
resBody, err := compressBlob(resBytes)
c.So(err, ShouldBeNil)
w.Header().Set("Content-Type", mtPRPCBinary)
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set(HeaderGRPCCode, "0")
w.WriteHeader(http.StatusOK)
_, err = w.Write(resBody)
c.So(err, ShouldBeNil)
})
defer server.Close()
client.EnableRequestCompression = true
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "compressed response")
})
Convey("With a deadline <= now, does not execute.", func(c C) {
// doPanicHandler panics on any request, so reaching the server at all
// would be noticed.
client, server := setUp(doPanicHandler)
defer server.Close()
ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx))
defer cancelFunc()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.DeadlineExceeded)
So(err, ShouldErrLike, "overall deadline exceeded")
})
Convey("With a deadline in the future, sets the deadline header.", func(c C) {
// sayHello asserts that the timeout header, when present, equals
// "10000000u", matching the 10s deadline set here.
client, server := setUp(sayHello(c))
defer server.Close()
ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx).Add(10*time.Second))
defer cancelFunc()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "Hello John")
So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
})
Convey("With a deadline in the future and a per-RPC deadline, applies the per-RPC deadline", func(c C) {
// Set an overall deadline.
overallDeadline := time.Second + 500*time.Millisecond
ctx, cancel := clock.WithTimeout(ctx, overallDeadline)
defer cancel()
// Every request advances the test clock by 1s and fails with HTTP 500.
// With a 1s per-RPC timeout and a 1.5s overall deadline, exactly two
// attempts are expected (asserted via calls below).
client, server := setUp(advanceClockAndErr(tc, time.Second))
defer server.Close()
calls := 0
// All of our HTTP requests should terminate >= timeout. Synchronize
// around this to ensure that our Context is always the functional
// client error.
client.testPostHTTP = func(ctx context.Context, err error) error {
calls++
<-ctx.Done()
return ctx.Err()
}
client.Options.PerRPCTimeout = time.Second
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.DeadlineExceeded)
So(err, ShouldErrLike, "overall deadline exceeded")
So(calls, ShouldEqual, 2)
})
Convey(`With a maximum content length smaller than the response, returns "ErrResponseTooBig".`, func(c C) {
client, server := setUp(sayHello(c))
defer server.Close()
client.MaxContentLength = 8
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldEqual, ErrResponseTooBig)
})
Convey(`When the response returns a huge Content Length, returns "ErrResponseTooBig".`, func(c C) {
client, server := setUp(sayHello(c))
defer server.Close()
// This name makes sayHello advertise a Content-Length of 999999999999.
req.Name = "TOO BIG"
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldEqual, ErrResponseTooBig)
})
Convey("Doesn't log expected codes", func(c C) {
client, server := setUp(sayHello(c))
defer server.Close()
// This name makes sayHello reply with codes.NotFound / HTTP 404.
req.Name = "NOT FOUND"
// Have it logged by default
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.NotFound)
So(log, shouldHaveMessagesLike,
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"})
log.Reset()
// And don't have it if using ExpectedCode.
err = client.Call(ctx, "prpc.Greeter", "SayHello", req, res, ExpectedCode(codes.NotFound))
So(status.Code(err), ShouldEqual, codes.NotFound)
So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
})
Convey("HTTP 500 x2", func(c C) {
// The first two requests fail with HTTP 500 and codes.Internal; the third
// one reaches sayHello and succeeds.
client, server := setUp(transientErrors(2, true, http.StatusInternalServerError, sayHello(c)))
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(err, ShouldBeNil)
So(res.Message, ShouldEqual, "Hello John")
So(log, shouldHaveMessagesLike,
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
)
})
Convey("HTTP 500 many", func(c C) {
// The server fails 10 times, more than the client will attempt: the log
// assertions below expect four attempts, the last one failing permanently.
client, server := setUp(transientErrors(10, true, http.StatusInternalServerError, sayHello(c)))
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.Internal)
So(status.Convert(err).Message(), ShouldEqual, "Server misbehaved")
So(log, shouldHaveMessagesLike,
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
)
})
Convey("HTTP 500 without gRPC header", func(c C) {
// Same as above, but the failing responses carry no gRPC code header; the
// resulting code is still expected to be codes.Internal.
client, server := setUp(transientErrors(10, false, http.StatusInternalServerError, sayHello(c)))
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.Internal)
So(log, shouldHaveMessagesLike,
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
)
})
Convey("HTTP 503 without gRPC header", func(c C) {
// Without a gRPC code header, HTTP 503 is expected to surface as
// codes.Unavailable.
client, server := setUp(transientErrors(10, false, http.StatusServiceUnavailable, sayHello(c)))
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.Unavailable)
})
Convey("Forbidden", func(c C) {
client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.PermissionDenied)))
w.WriteHeader(http.StatusForbidden)
fmt.Fprintln(w, "Access denied")
})
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.PermissionDenied)
// The body is written with a trailing newline; the error message is
// expected without it.
So(status.Convert(err).Message(), ShouldEqual, "Access denied")
So(log, shouldHaveMessagesLike,
expectedCallLogEntry(client),
memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
)
})
Convey(HeaderGRPCCode, func(c C) {
// The gRPC code header (Canceled) is expected to win over the HTTP status
// (400).
client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Canceled)))
w.WriteHeader(http.StatusBadRequest)
})
defer server.Close()
err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
So(status.Code(err), ShouldEqual, codes.Canceled)
})
Convey("Concurrency limit", func(c C) {
const (
maxConcurrentRequests = 3
totalRequests = 10
)
cur := int64(0)
// Buffered to totalRequests so that the handler's send does not block as
// long as each call results in a single request.
reports := make(chan int64, totalRequests)
// For each request record how many parallel requests were running at
// the same time.
client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
reports <- atomic.AddInt64(&cur, 1)
defer atomic.AddInt64(&cur, -1)
// Note: dependence on the real clock is racy, but in the worse case
// (if client.Call guts are extremely slow) we'll get a false positive
// result. In other words, if the code under test is correct (and it
// is right now), the test will always succeed no matter what. If the
// code under test is not correct (i.e. regresses), we'll start seeing
// test errors most of the time, with occasional false successes.
time.Sleep(200 * time.Millisecond)
sayHello(c)(w, r)
})
defer server.Close()
client.MaxConcurrentRequests = maxConcurrentRequests
// Execute a bunch of requests concurrently.
wg := sync.WaitGroup{}
for i := 0; i < totalRequests; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Each goroutine uses its own request/response messages rather than the
// shared req/res.
err := client.Call(ctx, "prpc.Greeter", "SayHello", &HelloRequest{Name: "John"}, &HelloReply{})
c.So(err, ShouldBeNil)
}()
}
wg.Wait()
// Make sure concurrency limit wasn't violated.
for i := 0; i < totalRequests; i++ {
select {
case concur := <-reports:
So(concur, ShouldBeLessThanOrEqualTo, maxConcurrentRequests)
default:
// All calls have returned by now, so a missing report means a request
// never reached the handler.
t.Fatal("Some requests didn't execute")
}
}
})
})
})
}