blob: 7cf4f2928b7dbd748679e101b29115adbdf3ff28 [file] [log] [blame]
// Copyright 2019 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"go.chromium.org/luci/common/clock/testclock"
"go.chromium.org/luci/common/logging"
"go.chromium.org/luci/common/logging/gkelogger"
"go.chromium.org/luci/server/router"
"go.chromium.org/luci/server/secrets"
. "github.com/smartystreets/goconvey/convey"
)
// TestServer starts a test server (see newTestServer: fake clock, in-memory
// stdout/stderr log recorders) and checks request handling end to end: the
// structured log entries it writes and access to secrets via secrets.GetSecret.
func TestServer(t *testing.T) {
	t.Parallel()
	Convey("Works", t, func() {
		// A fake clock makes log timestamps and the request latency deterministic.
		ctx, tc := testclock.UseTime(context.Background(), testclock.TestRecentTimeUTC)
		srv, err := newTestServer(ctx)
		So(err, ShouldBeNil)
		defer srv.cleanup()
		srv.ServeInBackground()
		// After each nested Convey, stop the server and require that serving
		// ended without an error.
		Reset(func() { So(srv.StopBackgroundServing(), ShouldBeNil) })
		Convey("Logging", func() {
			srv.Routes.GET("/test", router.MiddlewareChain{}, func(c *router.Context) {
				logging.Infof(c.Context, "Info log")
				// Advance the fake clock by 1s between the two log lines. This is
				// what produces the 1s gap between the two timestamps and the
				// "1.000000s" latency asserted below.
				tc.Add(time.Second)
				logging.Warningf(c.Context, "Warn log")
				c.Writer.WriteHeader(201)
				c.Writer.Write([]byte("Hello, world"))
			})
			resp, err := srv.Get("/test", map[string]string{
				"User-Agent":            "Test-user-agent",
				"X-Cloud-Trace-Context": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/00001;trace=TRUE",
				"X-Forwarded-For":       "1.1.1.1,2.2.2.2,3.3.3.3",
			})
			So(err, ShouldBeNil)
			So(resp, ShouldEqual, "Hello, world")
			// Stderr log captures details about the request.
			// NOTE(review): the request entry's severity is "warning" although it
			// also logged at info level. Presumably it takes the maximum severity
			// of the request's log lines; confirm.
			So(srv.stderr.Last(1), ShouldResemble, []gkelogger.LogEntry{
				{
					Severity: "warning",
					Time:     "1454472307.7",
					// The trace ID is the part of X-Cloud-Trace-Context before "/".
					TraceID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
					RequestInfo: &gkelogger.RequestInfo{
						Method:       "GET",
						URL:          srv.mainAddr + "/test",
						Status:       201,
						RequestSize:  "0",
						ResponseSize: "12", // len("Hello, world")
						UserAgent:    "Test-user-agent",
						// The second-to-last X-Forwarded-For entry. Presumably the
						// last hop is assumed to be the load balancer; confirm.
						RemoteIP: "2.2.2.2",
						Latency:  "1.000000s",
					},
				},
			})
			// Stdout log captures individual log lines.
			// Both lines share one operation ID. NOTE(review): the ID is presumably
			// deterministic because newTestServer passes testSeed: 1; confirm.
			So(srv.stdout.Last(2), ShouldResemble, []gkelogger.LogEntry{
				{
					Severity: "info",
					Message:  "Info log",
					Time:     "1454472306.7",
					TraceID:  "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
					Operation: &gkelogger.Operation{
						ID: "9566c74d10037c4d7bbb0407d1e2c649",
					},
				},
				{
					Severity: "warning",
					Message:  "Warn log",
					Time:     "1454472307.7",
					TraceID:  "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
					Operation: &gkelogger.Operation{
						ID: "9566c74d10037c4d7bbb0407d1e2c649",
					},
				},
			})
		})
		Convey("Secrets", func() {
			// The handler replies 500 if the secret can't be fetched, and Get
			// turns that status into an error.
			srv.Routes.GET("/secret", router.MiddlewareChain{}, func(c *router.Context) {
				s, err := secrets.GetSecret(c.Context, "secret_name")
				if err != nil {
					c.Writer.WriteHeader(500)
				} else {
					c.Writer.Write([]byte(s.Current))
				}
			})
			resp, err := srv.Get("/secret", nil)
			So(err, ShouldBeNil)
			// Only the presence of a secret value is checked, not its content.
			So(resp, ShouldNotBeEmpty)
		})
	})
}
// BenchmarkServer measures the cost of dispatching one request through
// srv.Routes, with a handler that logs and looks up a secret. It bypasses the
// network stack by calling ServeHTTP directly.
func BenchmarkServer(b *testing.B) {
	srv, err := newTestServer(context.Background())
	if err != nil {
		b.Fatal(err)
	}
	defer srv.cleanup()

	// The route we are going to hit from the benchmark.
	srv.Routes.GET("/test", router.MiddlewareChain{}, func(c *router.Context) {
		logging.Infof(c.Context, "Hello, world")
		secrets.GetSecret(c.Context, "key-name") // e.g. checking XSRF token
		c.Writer.Write([]byte("Hello, world"))
	})

	// Don't actually store logs from the many iterations of the loop below.
	srv.stdout.discard = true
	srv.stderr.discard = true

	// Launch the server and wait for it to start serving to make sure all guts
	// are initialized.
	srv.ServeInBackground()
	defer func() {
		// Report, rather than silently drop, a failure of the serving goroutine.
		if err := srv.StopBackgroundServing(); err != nil {
			b.Errorf("server exited with error: %s", err)
		}
	}()
	if _, err = srv.Get("/health", nil); err != nil {
		b.Fatal(err)
	}

	// Actual benchmark loop. Note that we bypass the network layer here
	// completely (by calling srv.Routes.ServeHTTP directly instead of using
	// http.DefaultClient).
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		req, err := http.NewRequest("GET", "/test", nil)
		if err != nil {
			b.Fatal(err)
		}
		req.Header.Set("X-Cloud-Trace-Context", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/00001;trace=TRUE")
		req.Header.Set("X-Forwarded-For", "1.1.1.1,2.2.2.2,3.3.3.3")
		rr := httptest.NewRecorder()
		srv.Routes.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			b.Fatalf("unexpected status %d", rr.Code)
		}
	}
}
////////////////////////////////////////////////////////////////////////////////
// testServer is a *Server configured for tests. It also holds the in-memory
// log sinks the server writes to and the state of its background serving
// goroutine.
type testServer struct {
	*Server

	// Recorders installed as the server's stdout and stderr log sinks.
	stdout, stderr logsRecorder

	// Base URL of the main HTTP listener: "http://127.0.0.1:<port>".
	mainAddr string
	// Removes the temporary files created by newTestServer.
	cleanup func()
	// Result of ListenAndServe, set when ServeInBackground's goroutine exits.
	serveErr errorEvent
}
// newTestServer builds, but does not start, a Server with Prod: true. It uses
// ctx as the test context, a freshly written temporary root secret, and two
// listeners on auto-assigned loopback ports.
//
// The caller must invoke cleanup on the result to remove the temporary secret
// file. It panics if a listening socket cannot be opened.
func newTestServer(ctx context.Context) (*testServer, error) {
	secretFile, err := tempSecret()
	if err != nil {
		return nil, err
	}

	// Bind to auto-assigned ports first so the main address is known up front.
	mainLis := setupListener()
	adminLis := setupListener()
	mainPort := mainLis.Addr().(*net.TCPAddr).Port

	ts := &testServer{
		mainAddr: fmt.Sprintf("http://127.0.0.1:%d", mainPort),
		cleanup:  func() { os.Remove(secretFile.Name()) },
		serveErr: errorEvent{signal: make(chan struct{})},
	}
	ts.Server = New(Options{
		Prod:           true,
		HTTPAddr:       "main_addr",
		AdminAddr:      "admin_addr",
		RootSecretPath: secretFile.Name(),
		testCtx:        ctx,
		testSeed:       1,
		testStdout:     &ts.stdout,
		testStderr:     &ts.stderr,
		testListeners: map[string]net.Listener{
			"main_addr":  mainLis,
			"admin_addr": adminLis,
		},
	})
	return ts, nil
}
// ServeInBackground runs ListenAndServe on a new goroutine. When it returns,
// its result is recorded in s.serveErr, which also signals that serving ended.
func (s *testServer) ServeInBackground() {
	go func() {
		serveResult := s.ListenAndServe()
		s.serveErr.Set(serveResult)
	}()
}
// StopBackgroundServing shuts the server down, then blocks until the goroutine
// started by ServeInBackground has finished. It returns the error
// ListenAndServe reported, or nil if it reported none.
func (s *testServer) StopBackgroundServing() error {
	s.Shutdown()
	serveResult := s.serveErr.Get() // blocks until ListenAndServe has returned
	return serveResult
}
// Get makes a blocking GET request for uri against the main port, with the
// given headers set, and returns the response body.
//
// A response with status >= 400 is returned together with an "unexpected
// status" error. If the server stops serving first, the request is abandoned.
// Get then returns the server's error, or a "server stopped" error if it
// exited cleanly, so a dead server never looks like a successful empty reply.
func (s *testServer) Get(uri string, headers map[string]string) (resp string, err error) {
	type result struct {
		body string
		err  error
	}

	fetch := func() (string, error) {
		req, err := http.NewRequest("GET", s.mainAddr+uri, nil)
		if err != nil {
			return "", err
		}
		for k, v := range headers {
			req.Header.Set(k, v)
		}
		res, err := http.DefaultClient.Do(req)
		if err != nil {
			return "", err
		}
		defer res.Body.Close()
		blob, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return "", err
		}
		if res.StatusCode >= 400 {
			return string(blob), fmt.Errorf("unexpected status %d", res.StatusCode)
		}
		return string(blob), nil
	}

	// The request runs on its own goroutine so we can stop waiting if the server
	// dies. It reports through a buffered channel instead of assigning the named
	// results. That way, if we return early, the goroutine can still finish and
	// exit without racing with the caller reading resp/err.
	done := make(chan result, 1)
	go func() {
		body, fetchErr := fetch()
		done <- result{body: body, err: fetchErr}
	}()

	select {
	case r := <-done:
		return r.body, r.err
	case <-s.serveErr.signal:
		if serveErr := s.serveErr.Get(); serveErr != nil {
			return "", serveErr
		}
		return "", fmt.Errorf("server stopped before GET %s completed", uri)
	}
}
////////////////////////////////////////////////////////////////////////////////
// tempSecret writes a JSON-encoded secrets.Secret, whose current value is
// "test secret", into a new temporary file and returns that file, already
// closed. Callers only need its Name.
//
// On any error the temporary file, if it was created, is closed and removed,
// and a nil *os.File is returned.
func tempSecret() (out *os.File, err error) {
	f, err := ioutil.TempFile("", "luci-server-test")
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			os.Remove(f.Name())
		}
	}()

	secret := secrets.Secret{Current: []byte("test secret")}
	if err = json.NewEncoder(f).Encode(&secret); err != nil {
		f.Close() // release the descriptor; the deferred cleanup removes the file
		return nil, err
	}
	if err = f.Close(); err != nil {
		return nil, err
	}
	return f, nil
}
// setupListener opens a TCP listener on an OS-assigned port on 127.0.0.1.
// It panics on failure, which is acceptable since newTestServer uses it only
// during test setup.
func setupListener() net.Listener {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return lis
	}
	panic(err)
}
////////////////////////////////////////////////////////////////////////////////
// errorEvent is a one-shot event carrying an optional error: Set fires it, and
// Get blocks until it has fired.
type errorEvent struct {
	err    atomic.Value  // holds the error passed to Set, if it was non-nil
	signal chan struct{} // closed after 'err' is populated
}

// Set records err, if non-nil, and then fires the event by closing signal.
// It must be called at most once; a second call would close signal twice.
func (e *errorEvent) Set(err error) {
	defer close(e.signal) // runs after the Store below
	// atomic.Value.Store panics on nil, so only real errors are stored.
	if err != nil {
		e.err.Store(err)
	}
}

// Get waits until Set has been called and returns the error it received, or
// nil if it received nil. It may be called any number of times.
func (e *errorEvent) Get() error {
	<-e.signal
	if err, ok := e.err.Load().(error); ok {
		return err
	}
	return nil
}
////////////////////////////////////////////////////////////////////////////////
// logsRecorder is an in-memory sink for gkelogger log entries. The tests use
// it to capture what the server writes to stdout and stderr.
//
// Write and Last are guarded by a mutex. discard is read without the lock, so
// it must be set before the recorder is used from other goroutines.
type logsRecorder struct {
	discard bool // if true, Write drops entries instead of recording them

	m    sync.Mutex
	logs []gkelogger.LogEntry // guarded by m
}

// Write records a copy of *e, unless discard is set.
func (r *logsRecorder) Write(e *gkelogger.LogEntry) {
	if r.discard {
		return
	}
	r.m.Lock()
	defer r.m.Unlock()
	r.logs = append(r.logs, *e)
}

// Last returns a copy of the n most recently recorded entries, oldest first.
// If fewer than n entries were recorded, all of them are returned, so
// assertions fail with a readable diff instead of a slice-bounds panic.
// A negative n yields an empty slice.
func (r *logsRecorder) Last(n int) []gkelogger.LogEntry {
	r.m.Lock()
	defer r.m.Unlock()
	if n > len(r.logs) {
		n = len(r.logs)
	}
	if n < 0 {
		n = 0
	}
	entries := make([]gkelogger.LogEntry, n)
	copy(entries, r.logs[len(r.logs)-n:])
	return entries
}