blob: 9ab71c6eb4dc37d6771296c1f17dae6062f7426e [file] [log] [blame]
// Copyright 2020 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package grpcutil
import (
"context"
"google.golang.org/grpc"
)
// ChainUnaryServerInterceptors combines several unary interceptors into one.
//
// Interceptors run in the order given: the first one wraps all the others and
// the last one sits right next to the handler, i.e.
// `ChainUnaryServerInterceptors(a, b, c)(h) === a(b(c(h)))`.
//
// nil-valued interceptors are silently skipped. If no non-nil interceptor is
// given, the result is a pass-through interceptor that just calls the handler.
func ChainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	// Begin with a noop interceptor and fold the list from right to left, so
	// that earlier interceptors end up wrapping later ones.
	var chain grpc.UnaryServerInterceptor = func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		return handler(ctx, req)
	}
	last := len(interceptors) - 1
	for i := last; i >= 0; i-- {
		cur := interceptors[i]
		switch {
		case cur == nil:
			// Skip nils: the chain built so far is kept as is.
		case i == last:
			// Nothing follows this interceptor, so no chaining is needed.
			chain = cur
		default:
			chain = unaryCombinator(cur, chain)
		}
	}
	return chain
}
// unaryCombinator glues two unary interceptors into a single one: `first` runs
// on the outside and is given a handler that invokes `second`, which in turn
// is given the real handler.
func unaryCombinator(first, second grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		// Whatever context and request `first` passes along are forwarded to
		// `second` together with the original call info and handler.
		next := func(innerCtx context.Context, innerReq any) (any, error) {
			return second(innerCtx, innerReq, info, handler)
		}
		return first(ctx, req, info, next)
	}
}
// ChainStreamServerInterceptors combines several stream interceptors into one.
//
// Interceptors run in the order given: the first one wraps all the others and
// the last one sits right next to the handler, i.e.
// `ChainStreamServerInterceptors(a, b, c)(h) === a(b(c(h)))`.
//
// nil-valued interceptors are silently skipped. If no non-nil interceptor is
// given, the result is a pass-through interceptor that just calls the handler.
func ChainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
	// Begin with a noop interceptor and fold the list from right to left, so
	// that earlier interceptors end up wrapping later ones.
	var chain grpc.StreamServerInterceptor = func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return handler(srv, ss)
	}
	last := len(interceptors) - 1
	for i := last; i >= 0; i-- {
		cur := interceptors[i]
		switch {
		case cur == nil:
			// Skip nils: the chain built so far is kept as is.
		case i == last:
			// Nothing follows this interceptor, so no chaining is needed.
			chain = cur
		default:
			chain = streamCombinator(cur, chain)
		}
	}
	return chain
}
// streamCombinator glues two stream interceptors into a single one: `first`
// runs on the outside and is given a handler that invokes `second`, which in
// turn is given the real handler.
func streamCombinator(first, second grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		// Whatever service value and stream `first` passes along are forwarded
		// to `second` together with the original call info and handler.
		next := func(innerSrv any, innerSS grpc.ServerStream) error {
			return second(innerSrv, innerSS, info, handler)
		}
		return first(srv, ss, info, next)
	}
}
// ModifyServerStreamContext wraps the given ServerStream so that its Context()
// returns whatever the callback produced, while every other method is
// delegated to the original stream.
//
// This is handy when implementing stream server interceptors that need to
// put stuff into the stream's context.
//
// The callback is invoked right away, exactly once, with the stream's current
// context. It must return a context derived from the one it receives, or nil
// if no modification is actually necessary. When the callback returns nil or
// the very same context, the original stream is returned unwrapped.
func ModifyServerStreamContext(ss grpc.ServerStream, cb func(context.Context) context.Context) grpc.ServerStream {
	base := ss.Context()
	if derived := cb(base); derived != nil && derived != base {
		return &wrappedSS{ServerStream: ss, ctx: derived}
	}
	// Nothing changed, so there is no need for a wrapper.
	return ss
}
// wrappedSS is a grpc.ServerStream that replaces the context.
//
// It embeds the original stream, so every method other than Context is
// promoted from it unchanged.
type wrappedSS struct {
// The stream being wrapped.
grpc.ServerStream
// The context returned by Context() in place of the wrapped stream's own.
ctx context.Context
}
// Context returns the replacement context stored in the wrapper rather than
// the context of the embedded stream.
//
// This is part of grpc.ServerStream interface.
func (w *wrappedSS) Context() context.Context {
	return w.ctx
}
// UnifiedServerInterceptor can be converted into a unary or stream server
// interceptor via its Unary and Stream methods.
//
// Such an interceptor can do something at the start of the request (in
// particular modify the request context) and do something with the request
// error at the end. It can also skip the request entirely by returning an
// error without calling the handler.
//
// It receives the request context, the full name of the called method (taken
// from the FullMethod field of the gRPC call info) and a handler callback.
// Calling the handler with a context runs the rest of the request under that
// context and returns the resulting error.
//
// It is handy when implementing simple interceptors that can be used as both
// unary and stream ones.
type UnifiedServerInterceptor func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error
// Unary returns a unary form of the interceptor.
//
// The response produced by the wrapped handler is returned only when the
// interceptor itself finishes without an error; otherwise the response is
// dropped and only the error is returned.
func (u UnifiedServerInterceptor) Unary() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		// Captures the handler's response, since the unified callback can
		// only report an error.
		var out any
		invoke := func(ctx context.Context) error {
			var err error
			out, err = handler(ctx, req)
			return err
		}
		if err := u(ctx, info.FullMethod, invoke); err != nil {
			return nil, err
		}
		return out, nil
	}
}
// Stream returns a stream form of the interceptor.
//
// The interceptor is invoked with the stream's context. If it passes a
// different context to the handler, the stream handed to the real handler is
// wrapped so that its Context() returns that new context.
func (u UnifiedServerInterceptor) Stream() grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		base := ss.Context()
		return u(base, info.FullMethod, func(ctx context.Context) error {
			// Use the stream as is unless the context was actually replaced.
			stream := ss
			if ctx != base {
				stream = &wrappedSS{ServerStream: ss, ctx: ctx}
			}
			return handler(srv, stream)
		})
	}
}