// blob: 5c3cc1c1877cdb2c28ffc0b70367121fa66b22f7 [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package router
import (
"context"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
// TestRouter exercises the router: construction via New, middleware
// registration through Use, Subrouter and Handle, the execution order enforced
// by run, request dispatch through ServeHTTP, and path joining in makeBasePath.
//
// The middleware and handlers defined here record the order in which they run
// by appending markers to a []string stored in the request context under
// outputKey; assertions then compare that slice against the expected trace.
func TestRouter(t *testing.T) {
	t.Parallel()

	// ctxKey is a test-private context key type. The context package advises
	// against using string (or any other built-in type) as a key, because
	// unrelated packages could then collide on the same key value.
	type ctxKey string
	const outputKey ctxKey = "output"

	client := &http.Client{}

	// appendValue returns a context derived from ctx in which the []string
	// stored under key has val appended. A missing value is treated as an
	// empty slice.
	appendValue := func(ctx context.Context, key ctxKey, val string) context.Context {
		var current []string
		if v := ctx.Value(key); v != nil {
			current = v.([]string)
		}
		return context.WithValue(ctx, key, append(current, val))
	}

	// a and b record a marker both before and after calling next.
	a := func(c *Context, next Handler) {
		c.Context = appendValue(c.Context, outputKey, "a:before")
		next(c)
		c.Context = appendValue(c.Context, outputKey, "a:after")
	}
	b := func(c *Context, next Handler) {
		c.Context = appendValue(c.Context, outputKey, "b:before")
		next(c)
		c.Context = appendValue(c.Context, outputKey, "b:after")
	}
	// c records a marker only before calling next.
	c := func(c *Context, next Handler) {
		c.Context = appendValue(c.Context, outputKey, "c")
		next(c)
	}
	// d records a marker only after next returns.
	d := func(c *Context, next Handler) {
		next(c)
		c.Context = appendValue(c.Context, outputKey, "d")
	}
	// stop never calls next, so nothing after it in the chain should run.
	stop := func(_ *Context, _ Handler) {}
	handler := func(c *Context) {
		c.Context = appendValue(c.Context, outputKey, "handler")
	}

	Convey("Router", t, func() {
		r := New()

		Convey("New", func() {
			Convey("Should initialize non-nil httprouter.Router", func() {
				So(r.hrouter, ShouldNotBeNil)
			})
		})

		Convey("Use", func() {
			Convey("Should append middleware", func() {
				So(len(r.middleware), ShouldEqual, 0)
				r.Use(NewMiddlewareChain(a, b))
				So(len(r.middleware), ShouldEqual, 2)
			})
		})

		Convey("Subrouter", func() {
			Convey("Should create new router with values from original router", func() {
				r.BasePath = "/foo"
				r.Use(NewMiddlewareChain(a, b))
				r2 := r.Subrouter("bar")
				So(r.hrouter, ShouldPointTo, r2.hrouter)
				So(r2.BasePath, ShouldEqual, "/foo/bar")
				So(r.middleware, ShouldResemble, r2.middleware)
			})
		})

		Convey("Handle", func() {
			Convey("Should not modify existing empty r.middleware slice", func() {
				So(len(r.middleware), ShouldEqual, 0)
				r.Handle("GET", "/bar", NewMiddlewareChain(b, c), handler)
				So(len(r.middleware), ShouldEqual, 0)
			})

			Convey("Should not modify existing r.middleware slice", func() {
				r.Use(NewMiddlewareChain(a))
				So(len(r.middleware), ShouldEqual, 1)
				r.Handle("GET", "/bar", NewMiddlewareChain(b, c), handler)
				So(len(r.middleware), ShouldEqual, 1)
			})
		})

		Convey("run", func() {
			ctx := &Context{Context: context.Background()}

			Convey("Should execute handler when using nil middlewares", func() {
				run(ctx, nil, nil, handler)
				So(ctx.Context.Value(outputKey), ShouldResemble, []string{"handler"})
			})

			Convey("Should execute middlewares and handler in order", func() {
				m := NewMiddlewareChain(a, b, c)
				n := NewMiddlewareChain(d)
				run(ctx, m, n, handler)
				So(ctx.Context.Value(outputKey), ShouldResemble,
					[]string{"a:before", "b:before", "c", "handler", "d", "b:after", "a:after"},
				)
			})

			Convey("Should not execute upcoming middleware/handlers if next is not called", func() {
				mc := NewMiddlewareChain(a, stop, b)
				run(ctx, mc, NewMiddlewareChain(), handler)
				So(ctx.Context.Value(outputKey), ShouldResemble, []string{"a:before", "a:after"})
			})

			Convey("Should execute next middleware when it encounters nil middleware", func() {
				Convey("At start of first chain", func() {
					run(ctx, NewMiddlewareChain(nil, a), NewMiddlewareChain(b), handler)
					So(ctx.Context.Value(outputKey), ShouldResemble, []string{"a:before", "b:before", "handler", "b:after", "a:after"})
				})

				Convey("At start of second chain", func() {
					run(ctx, NewMiddlewareChain(a), NewMiddlewareChain(nil, b), handler)
					So(ctx.Context.Value(outputKey), ShouldResemble, []string{"a:before", "b:before", "handler", "b:after", "a:after"})
				})

				Convey("At end of first chain", func() {
					run(ctx, NewMiddlewareChain(a, nil), NewMiddlewareChain(b), handler)
					So(ctx.Context.Value(outputKey), ShouldResemble, []string{"a:before", "b:before", "handler", "b:after", "a:after"})
				})

				Convey("At end of second chain", func() {
					run(ctx, NewMiddlewareChain(a), NewMiddlewareChain(b, nil), handler)
					So(ctx.Context.Value(outputKey), ShouldResemble, []string{"a:before", "b:before", "handler", "b:after", "a:after"})
				})
			})
		})

		Convey("ServeHTTP", func() {
			ts := httptest.NewServer(r)

			// This a intentionally shadows the outer one: in addition to
			// recording its markers it writes the accumulated trace to the
			// response so the client side can assert on it.
			a := func(c *Context, next Handler) {
				c.Context = appendValue(c.Context, outputKey, "a:before")
				next(c)
				c.Context = appendValue(c.Context, outputKey, "a:after")
				// The write result is deliberately discarded: a failed write
				// shows up as a body mismatch in the client-side assertion.
				_, _ = io.WriteString(c.Writer, strings.Join(c.Context.Value(outputKey).([]string), ","))
			}

			Convey("Should execute middleware registered in Use and Handle in order", func() {
				r.Use(NewMiddlewareChain(a))
				r.GET("/ab", NewMiddlewareChain(b), handler)
				res, err := client.Get(ts.URL + "/ab")
				So(err, ShouldBeNil)
				defer res.Body.Close()
				So(res.StatusCode, ShouldEqual, http.StatusOK)
				p, err := ioutil.ReadAll(res.Body)
				So(err, ShouldBeNil)
				So(string(p), ShouldEqual, strings.Join(
					[]string{"a:before", "b:before", "handler", "b:after", "a:after"},
					",",
				))
			})

			Convey("Should return method not allowed for existing path but wrong method in request", func() {
				r.POST("/comment", NewMiddlewareChain(), handler)
				res, err := client.Get(ts.URL + "/comment")
				So(err, ShouldBeNil)
				defer res.Body.Close()
				So(res.StatusCode, ShouldEqual, http.StatusMethodNotAllowed)
			})

			Convey("Should return expected response from handler", func() {
				// Shadows the outer handler: this one echoes the first route
				// parameter back in the response body.
				handler := func(c *Context) {
					// Write result discarded for the same reason as above.
					_, _ = c.Writer.Write([]byte("Hello, " + c.Params[0].Value))
				}
				r.GET("/hello/:name", NewMiddlewareChain(c, d), handler)
				res, err := client.Get(ts.URL + "/hello/世界")
				So(err, ShouldBeNil)
				defer res.Body.Close()
				So(res.StatusCode, ShouldEqual, http.StatusOK)
				p, err := ioutil.ReadAll(res.Body)
				So(err, ShouldBeNil)
				So(string(p), ShouldEqual, "Hello, 世界")
			})

			// Shut the test server down as part of this scope's teardown.
			Reset(func() {
				ts.Close()
			})
		})

		Convey("makeBasePath", func() {
			cases := []struct{ base, relative, result string }{
				{"/", "", "/"},
				{"", "", "/"},
				{"foo", "", "/foo"},
				{"foo/", "", "/foo/"},
				{"foo", "/", "/foo/"},
				{"foo", "bar", "/foo/bar"},
				{"foo/", "bar", "/foo/bar"},
				{"foo", "bar/", "/foo/bar/"},
				{"/foo", "/bar", "/foo/bar"},
				{"/foo/", "/bar", "/foo/bar"},
				{"foo//", "///bar", "/foo/bar"},
				{"foo", "bar///baz/qux", "/foo/bar/baz/qux"},
				{"//foo//", "///bar///baz/qux/", "/foo/bar/baz/qux/"},
			}
			// tc (not c) so the loop variable does not shadow the c middleware.
			for _, tc := range cases {
				So(makeBasePath(tc.base, tc.relative), ShouldEqual, tc.result)
			}
		})
	})
}