Merge pull request #126 from GoogleCloudPlatform/contextdecorator

Add an extension mechanism to decorate contexts
diff --git a/endpoints/server.go b/endpoints/server.go
index 002026b..59a69c6 100644
--- a/endpoints/server.go
+++ b/endpoints/server.go
@@ -12,6 +12,8 @@
 	"net/http"
 	"reflect"
 	"strings"
+
+	"golang.org/x/net/context"
 	// Mainly for debug logging
 	"io/ioutil"
 
@@ -22,6 +24,10 @@
 type Server struct {
 	root     string
 	services *serviceMap
+
+	// ContextDecorator, if non-nil, is applied by ServeHTTP to each new request context.
+	// Its result replaces that context; a non-nil error is written as the response instead.
+	ContextDecorator func(context.Context) (context.Context, error)
 }
 
 // NewServer returns a new RPC server.
@@ -99,6 +105,14 @@
 // ServeHTTP is Server's implementation of http.Handler interface.
 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	c := NewContext(r)
+	if s.ContextDecorator != nil {
+		ctx, err := s.ContextDecorator(c)
+		if err != nil {
+			writeError(w, err)
+			return
+		}
+		c = ctx
+	}
 
 	// Always respond with JSON, even when an error occurs.
 	// Note: API server doesn't expect an encoding in Content-Type header.
diff --git a/endpoints/server_test.go b/endpoints/server_test.go
index 9b6f074..397db57 100644
--- a/endpoints/server_test.go
+++ b/endpoints/server_test.go
@@ -411,3 +411,52 @@
 		t.Fatalf("expected %q; got %q", body, res)
 	}
 }
+
+type contextDecoratorKeyType string // own key type, per context.WithValue's advice against string keys
+
+const contextDecoratorKey contextDecoratorKeyType = "context_decorator_key"
+const contextDecoratorValue = "context_decorator_value"
+
+// ContextDecorator succeeds only if ServeHTTP passed it the decorated context.
+func (s *ServerTestService) ContextDecorator(ctx context.Context) (*VoidMessage, error) {
+	if got := ctx.Value(contextDecoratorKey); got != contextDecoratorValue {
+		return nil, NewBadRequestError("wrong context value: %v", got)
+	}
+	return &VoidMessage{}, nil
+}
+
+func TestContextDecorator(t *testing.T) {
+	server := createAPIServer()
+	inst, err := aetest.NewInstance(nil)
+	if err != nil {
+		t.Fatalf("failed to create instance: %v", err)
+	}
+	defer inst.Close()
+
+	server.ContextDecorator = func(ctx context.Context) (context.Context, error) {
+		return nil, ConflictError
+	}
+	path := "/ServerTestService.ContextDecorator"
+	r, _ := inst.NewRequest("GET", path, strings.NewReader(""))
+	w := httptest.NewRecorder()
+	server.ServeHTTP(w, r)
+	if w.Code != http.StatusConflict {
+		t.Errorf("expected status code Conflict (409); got %v", w.Code)
+		msg, _ := ioutil.ReadAll(w.Body)
+		t.Errorf("response body: %s", msg)
+	}
+
+	server.ContextDecorator = func(ctx context.Context) (context.Context, error) {
+		fmt.Println("context decorated")
+		return context.WithValue(ctx, contextDecoratorKey, contextDecoratorValue), nil
+	}
+
+	r, _ = inst.NewRequest("POST", path, strings.NewReader("{}"))
+	w = httptest.NewRecorder()
+	server.ServeHTTP(w, r)
+	if w.Code != http.StatusOK {
+		t.Errorf("expected status OK (200); got %v", w.Code)
+		msg, _ := ioutil.ReadAll(w.Body)
+		t.Errorf("response body: %s", msg)
+	}
+}