blob: 7a26e5d852f7ac37200617212f68bbe3a7c0bf3a [file] [log] [blame]
// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package rpc
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/signal"
"strconv"
"sync"
"golang.org/x/sys/unix"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection"
"chromiumos/tast/errors"
"chromiumos/tast/internal/logging"
"chromiumos/tast/internal/protocol"
"chromiumos/tast/internal/testcontext"
"chromiumos/tast/internal/testing"
"chromiumos/tast/internal/timing"
)
// RunServer runs a gRPC server on r/w channels.
//
// The first message read from r must be a HandshakeRequest. register is called
// back to register core services. svcs is a list of user-defined gRPC services
// that are registered only if the client requests them via NeedUserServices in
// the HandshakeRequest. A HandshakeResponse describing the outcome of server
// initialization is written to w before any RPC is served.
//
// RunServer blocks until the client connection is closed, SIGINT/SIGTERM is
// caught, or it encounters an error. It returns only after all active method
// calls have finished.
func RunServer(r io.Reader, w io.Writer, svcs []*testing.Service, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
	// In case w is stdout or stderr, writing data to it after it is closed
	// causes SIGPIPE to be delivered to the process, which by default
	// terminates the process without running deferred cleanup calls.
	// To avoid the issue, ignore SIGPIPE while running the gRPC server.
	// See https://golang.org/pkg/os/signal/#hdr-SIGPIPE for more details.
	signal.Ignore(unix.SIGPIPE)
	defer signal.Reset(unix.SIGPIPE)

	var req protocol.HandshakeRequest
	if err := receiveRawMessage(r, &req); err != nil {
		return err
	}

	// Make sure to return only after all active method calls finish.
	// Otherwise the process can exit before running deferred function
	// calls on service goroutines.
	var calls sync.WaitGroup
	defer calls.Wait()

	// Start a remote logging server. It is used to forward logs from
	// user-defined gRPC services via side channels.
	ls := newRemoteLoggingServer()
	srv := grpc.NewServer(serverOpts(ls, &calls)...)

	// Register core services.
	regErr := registerCoreServices(srv, ls, &req, register)

	// Create a server-scoped context; it is canceled when RunServer returns.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Register user-defined gRPC services if requested. Skip this when core
	// registration already failed, since the server will not be started.
	if regErr == nil && req.GetNeedUserServices() {
		regErr = registerUserServices(ctx, srv, ls, &req, svcs, false)
	}

	if regErr != nil {
		// Format regErr (not the wrapped error) so the prefix is not repeated.
		res := &protocol.HandshakeResponse{
			Error: &protocol.HandshakeError{
				Reason: fmt.Sprintf("gRPC server initialization failed: %v", regErr),
			},
		}
		// Best-effort: the initialization error is what the caller needs to
		// see, so a secondary failure to report it to the client is dropped.
		_ = sendRawMessage(w, res)
		return errors.Wrap(regErr, "gRPC server initialization failed")
	}
	if err := sendRawMessage(w, &protocol.HandshakeResponse{}); err != nil {
		return err
	}

	// startServing stops the server on SIGINT/SIGTERM. io.EOF from the pipe
	// listener is treated as the client closing the connection normally.
	if err := startServing(srv, NewPipeListener(r, w)); err != nil && err != io.EOF {
		return err
	}
	return nil
}
// RunTCPServer runs a gRPC server listening on the specified port through TCP.
//
// port is the TCP port number the gRPC server listens on (IPv4, all
// interfaces). handshakeReq contains parameters needed to initialize the gRPC
// server. svcs is the candidate list of user-defined gRPC services; only those
// with GuaranteeCompatibility set are registered. register is called back to
// register additional core services.
//
// RunTCPServer blocks until SIGINT/SIGTERM is caught or the server encounters
// an error, and returns only after all active method calls have finished.
func RunTCPServer(port int, handshakeReq *protocol.HandshakeRequest, svcs []*testing.Service,
	register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
	// Make sure to return only after all active method calls finish.
	// Otherwise the process can exit before running deferred function
	// calls on service goroutines.
	var calls sync.WaitGroup
	defer calls.Wait()

	// Start a remote logging server. It is used to forward logs from
	// user-defined gRPC services via side channels.
	ls := newRemoteLoggingServer()
	srv := grpc.NewServer(serverOpts(ls, &calls)...)

	// Register core services.
	if err := registerCoreServices(srv, ls, handshakeReq, register); err != nil {
		return errors.Wrap(err, "gRPC server initialization failed")
	}

	// Create a server-scoped context; it is canceled when RunTCPServer returns.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Register user-defined gRPC services intended for public use.
	if err := registerUserServices(ctx, srv, ls, handshakeReq, svcs, true); err != nil {
		return errors.Wrap(err, "gRPC server initialization failed")
	}

	// Start the gRPC server listening on the TCP port.
	listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
	if err != nil {
		return errors.Wrap(err, "server failed to listen")
	}

	// startServing stops the server on SIGINT/SIGTERM. The io.EOF filter is
	// kept for parity with RunServer; a TCP listener presumably never yields it.
	if err := startServing(srv, listener); err != nil && err != io.EOF {
		return err
	}
	return nil
}
// serverStreamWithContext is a grpc.ServerStream whose Context method reports
// ctx rather than the wrapped stream's own context. Every other method is
// promoted unchanged from the embedded grpc.ServerStream.
type serverStreamWithContext struct {
	grpc.ServerStream
	ctx context.Context
}

// Context shadows the embedded grpc.ServerStream.Context and returns the
// overriding context.
func (ss *serverStreamWithContext) Context() context.Context { return ss.ctx }

// Compile-time assertion that *serverStreamWithContext is a grpc.ServerStream.
var _ grpc.ServerStream = (*serverStreamWithContext)(nil)
// serverOpts returns gRPC server-side interceptors to manipulate context.
//
// ls receives logs written through the logger attached to each call's context,
// and its last sequence number is reported in trailers of non-logging methods.
// calls is incremented for the duration of every unary and streaming method
// call so that the server owner can wait for in-flight calls to finish.
func serverOpts(ls *remoteLoggingServer, calls *sync.WaitGroup) []grpc.ServerOption {
	// hook is called on every gRPC method call.
	// It returns a Context to be passed to a gRPC method, a function to be
	// called on the end of the gRPC method call to compute trailers, and
	// possibly an error.
	hook := func(ctx context.Context, method string) (context.Context, func() metadata.MD, error) {
		// Forward all uncaptured logs via LoggingService.
		ctx = logging.AttachLogger(ctx, logging.NewSinkLogger(logging.LevelInfo, false, logging.NewFuncSink(ls.Log)))

		userMethod := isUserMethod(method)
		var outDir string
		var tl *timing.Log
		if userMethod {
			md, ok := metadata.FromIncomingContext(ctx)
			if !ok {
				return nil, nil, errors.New("metadata not available")
			}
			var err error
			outDir, err = ioutil.TempDir("", "rpc-outdir.")
			if err != nil {
				return nil, nil, err
			}
			// Make the directory world-writable so that tests can create files as other users,
			// and set the sticky bit to prevent users from deleting other users' files.
			if err := os.Chmod(outDir, 0777|os.ModeSticky); err != nil {
				// No trailer is produced on this path, so nothing else would
				// ever remove outDir; clean it up here (best-effort).
				_ = os.RemoveAll(outDir)
				return nil, nil, err
			}
			ctx = testcontext.WithCurrentEntity(ctx, incomingCurrentContext(md, outDir))
			tl = timing.NewLog()
			ctx = timing.NewContext(ctx, tl)
		}

		trailer := func() metadata.MD {
			md := make(metadata.MD)
			if userMethod {
				b, err := json.Marshal(tl)
				if err != nil {
					logging.Info(ctx, "Failed to marshal timing JSON: ", err)
				} else {
					md[metadataTiming] = []string{string(b)}
				}
				// Send metadataOutDir only if some files were saved in order to avoid extra round-trips.
				if fis, err := ioutil.ReadDir(outDir); err != nil {
					logging.Info(ctx, "gRPC output directory is corrupted: ", err)
				} else if len(fis) == 0 {
					if err := os.RemoveAll(outDir); err != nil {
						logging.Info(ctx, "Failed to remove gRPC output directory: ", err)
					}
				} else {
					md[metadataOutDir] = []string{outDir}
				}
			}
			if !isLoggingMethod(method) {
				md[metadataLogLastSeq] = []string{strconv.FormatUint(ls.LastSeq(), 10)}
			}
			return md
		}
		return ctx, trailer, nil
	}

	return []grpc.ServerOption{
		grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (res interface{}, err error) {
			calls.Add(1)
			defer calls.Done()
			ctx, trailer, err := hook(ctx, info.FullMethod)
			if err != nil {
				return nil, err
			}
			// Trailers are computed after the handler returns so they reflect
			// its timing log and output files.
			defer func() {
				grpc.SetTrailer(ctx, trailer())
			}()
			return handler(ctx, req)
		}),
		grpc.StreamInterceptor(func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			calls.Add(1)
			defer calls.Done()
			ctx, trailer, err := hook(stream.Context(), info.FullMethod)
			if err != nil {
				return err
			}
			// Wrap the stream so the handler observes the hooked context.
			stream = &serverStreamWithContext{stream, ctx}
			defer func() {
				stream.SetTrailer(trailer())
			}()
			return handler(srv, stream)
		}),
	}
}
// registerCoreServices installs the core Tast services on srv: gRPC server
// reflection, the logging service backed by ls, and a file transfer service.
// It then invokes register with srv and handshakeReq so the caller can add
// further services, and returns whatever error register reports.
func registerCoreServices(srv *grpc.Server, ls *remoteLoggingServer,
	handshakeReq *protocol.HandshakeRequest, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
	reflection.Register(srv)
	protocol.RegisterLoggingServer(srv, ls)

	fileTransfer := newFileTransferServer()
	protocol.RegisterFileTransferServer(srv, fileTransfer)

	// Caller-supplied hook for additional registrations.
	err := register(srv, handshakeReq)
	return err
}
// registerUserServices registers user-defined gRPC services on srv.
//
// Each service is given a ServiceState whose context carries a logger that
// forwards to ls, and a service root built from the variables found in
// handshakeReq's bundle init params. When guaranteeCompatibilityOnly is true,
// services without GuaranteeCompatibility set are skipped. The returned error
// is currently always nil.
func registerUserServices(ctx context.Context, srv *grpc.Server, ls *remoteLoggingServer,
	handshakeReq *protocol.HandshakeRequest, svcs []*testing.Service, guaranteeCompatibilityOnly bool) error {
	ctx = logging.AttachLogger(ctx, logging.NewSinkLogger(logging.LevelInfo, false, logging.NewFuncSink(ls.Log)))
	vars := handshakeReq.GetBundleInitParams().GetVars()

	for _, s := range svcs {
		if guaranteeCompatibilityOnly && !s.GuaranteeCompatibility {
			continue
		}
		root := testing.NewServiceRoot(s, vars)
		s.Register(srv, testing.NewServiceState(ctx, root))
	}
	return nil
}
// startServing serves gRPC requests from listener on srv until Serve returns.
//
// While serving, SIGINT and SIGTERM are caught and cause srv to be stopped.
// In that case a "caught signal" error is returned. Otherwise the error
// returned by srv.Serve (possibly nil) is returned unchanged.
func startServing(srv *grpc.Server, listener net.Listener) error {
	// From now on, catch SIGINT/SIGTERM to stop the server gracefully.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM)
	defer func() {
		// Stop must run before close: once Stop returns, no further signals
		// are delivered to sigCh, so closing it cannot race with a send.
		// Closing also lets the watcher goroutine below exit.
		signal.Stop(sigCh)
		close(sigCh)
	}()

	sigErrCh := make(chan error, 1)
	go func() {
		if sig, ok := <-sigCh; ok {
			// Publish the error before stopping, so it is already buffered
			// by the time Serve returns because of Stop.
			sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig)
			srv.Stop()
		}
	}()

	err := srv.Serve(listener)

	// grpc-go documents that Serve returns a nil error once Stop is called,
	// so the signal must be checked regardless of Serve's result; otherwise a
	// signal-triggered shutdown would be reported as success.
	select {
	case sigErr := <-sigErrCh:
		return sigErr
	default:
	}
	return err
}