Merge remote-tracking branch 'origin' into contextdecorator
diff --git a/endpoints/auth.go b/endpoints/auth.go
index de01240..714a3ac 100644
--- a/endpoints/auth.go
+++ b/endpoints/auth.go
@@ -102,11 +102,18 @@
errNoRequest = errors.New("no request for context (use endpoints.NewContext to create a context)")
)
+// ContextDecorator, if non-nil, is called by NewContext as its last step and its
+// result is returned in place of the context; if nil, no decoration happens.
+var ContextDecorator func(context.Context) context.Context
+
// NewContext returns a new context for an in-flight API (HTTP) request.
func NewContext(r *http.Request) context.Context {
c := appengine.NewContext(r)
c = context.WithValue(c, requestKey, r)
c = context.WithValue(c, authenticatorKey, AuthenticatorFactory())
+ if ContextDecorator != nil { // optional package-level hook; skipped when unset
+ c = ContextDecorator(c) // runs last, so it receives the context already carrying the request and authenticator values
+ }
return c
}
diff --git a/endpoints/auth_test.go b/endpoints/auth_test.go
index 8b0b8dc..e323865 100644
--- a/endpoints/auth_test.go
+++ b/endpoints/auth_test.go
@@ -432,3 +432,27 @@
AuthenticatorFactory = factory
return NewContext(r)
}
+
+func TestContextDecorator(t *testing.T) { // checks that NewContext applies the ContextDecorator hook
+ defer func(old func(context.Context) context.Context) { // put the previous hook back once the test returns
+ ContextDecorator = old
+ }(ContextDecorator)
+
+ inst, err := aetest.NewInstance(nil)
+ if err != nil {
+ t.Fatalf("failed to create instance: %v", err)
+ }
+ defer inst.Close()
+
+ key := "this is the key" // NOTE(review): plain string used as a context key; linters usually prefer an unexported key type
+ value := new(int) // fresh pointer, so the comparison below is by identity
+ ContextDecorator = func(ctx context.Context) context.Context {
+ return context.WithValue(ctx, key, value)
+ }
+
+ r, _ := inst.NewRequest("GET", "/", nil) // NOTE(review): second return value is discarded — presumably an error; confirm and check it
+ ctx := NewContext(r)
+ if got := ctx.Value(key); got != value { // the decorator's value must be visible on the context NewContext returns
+ t.Errorf("expected value in context was %v; got %v", value, got)
+ }
+}