Merge pull request #126 from GoogleCloudPlatform/contextdecorator
Add an extension mechanism to decorate contexts
diff --git a/endpoints/server.go b/endpoints/server.go
index 002026b..59a69c6 100644
--- a/endpoints/server.go
+++ b/endpoints/server.go
@@ -12,6 +12,8 @@
"net/http"
"reflect"
"strings"
+
+ "golang.org/x/net/context"
// Mainly for debug logging
"io/ioutil"
@@ -22,6 +24,10 @@
type Server struct {
root string
services *serviceMap
+
+	// ContextDecorator, if non-nil, is called by ServeHTTP on each new request context and the
+	// context it returns is used instead. If it returns an error, that error becomes the response.
+	ContextDecorator func(context.Context) (context.Context, error)
}
// NewServer returns a new RPC server.
@@ -99,6 +105,14 @@
// ServeHTTP is Server's implementation of http.Handler interface.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := NewContext(r)
+ if s.ContextDecorator != nil {
+ ctx, err := s.ContextDecorator(c)
+ if err != nil {
+ writeError(w, err)
+ return
+ }
+ c = ctx
+ }
// Always respond with JSON, even when an error occurs.
// Note: API server doesn't expect an encoding in Content-Type header.
diff --git a/endpoints/server_test.go b/endpoints/server_test.go
index 9b6f074..397db57 100644
--- a/endpoints/server_test.go
+++ b/endpoints/server_test.go
@@ -411,3 +411,64 @@
t.Fatalf("expected %q; got %q", body, res)
}
}
+
+// contextKey is the type of context keys used by this test. An unexported key type,
+// rather than a plain string, avoids collisions with context values set by other
+// packages, as the context package documentation recommends.
+type contextKey string
+
+const (
+	contextDecoratorKey   contextKey = "context_decorator_key"
+	contextDecoratorValue            = "context_decorator_value"
+)
+
+// ContextDecorator is a test endpoint that succeeds only if its context carries
+// contextDecoratorValue under contextDecoratorKey, i.e. only if the server's
+// ContextDecorator hook decorated the context before the method was invoked.
+func (s *ServerTestService) ContextDecorator(ctx context.Context) (*VoidMessage, error) {
+	if got := ctx.Value(contextDecoratorKey); got != contextDecoratorValue {
+		return nil, NewBadRequestError("wrong context value: %v", got)
+	}
+	return &VoidMessage{}, nil
+}
+
+// TestContextDecorator checks that an error from Server.ContextDecorator is written as
+// the response, and that a successfully decorated context reaches the service method.
+func TestContextDecorator(t *testing.T) {
+	server := createAPIServer()
+	inst, err := aetest.NewInstance(nil)
+	if err != nil {
+		t.Fatalf("failed to create instance: %v", err)
+	}
+	defer inst.Close()
+
+	path := "/ServerTestService.ContextDecorator"
+
+	// A decorator error must abort the request and determine the response status.
+	server.ContextDecorator = func(ctx context.Context) (context.Context, error) {
+		return nil, ConflictError
+	}
+	r, err := inst.NewRequest("GET", path, strings.NewReader(""))
+	if err != nil {
+		t.Fatalf("failed to create request: %v", err)
+	}
+	w := httptest.NewRecorder()
+	server.ServeHTTP(w, r)
+	if w.Code != http.StatusConflict {
+		t.Errorf("expected status code Conflict (409); got %v; response body: %s", w.Code, w.Body.String())
+	}
+
+	// A decorated context must be the one passed to the service method.
+	server.ContextDecorator = func(ctx context.Context) (context.Context, error) {
+		return context.WithValue(ctx, contextDecoratorKey, contextDecoratorValue), nil
+	}
+	r, err = inst.NewRequest("POST", path, strings.NewReader("{}"))
+	if err != nil {
+		t.Fatalf("failed to create request: %v", err)
+	}
+	w = httptest.NewRecorder()
+	server.ServeHTTP(w, r)
+	if w.Code != http.StatusOK {
+		t.Errorf("expected status OK (200); got %v; response body: %s", w.Code, w.Body.String())
+	}
+}