blob: ea9a93ea24045c940a2d5088267bf2af356a06e3 [file] [log] [blame]
// Copyright 2019 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package rpc
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
gotesting "testing"
"time"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/shirou/gopsutil/v3/process"
"golang.org/x/sys/unix"
"google.golang.org/grpc"
"go.chromium.org/tast/core/errors"
"go.chromium.org/tast/core/internal/fakeexec"
"go.chromium.org/tast/core/internal/logging"
"go.chromium.org/tast/core/internal/logging/loggingtest"
"go.chromium.org/tast/core/internal/protocol"
"go.chromium.org/tast/core/internal/sshtest"
"go.chromium.org/tast/core/internal/testcontext"
"go.chromium.org/tast/core/internal/testing"
"go.chromium.org/tast/core/internal/testingutil"
"go.chromium.org/tast/core/internal/timing"
"go.chromium.org/tast/core/ssh"
"go.chromium.org/tast/core/testutil"
)
// pingUserServiceName is the service name that tests in this file list in
// CurrentEntity.ServiceDeps to declare a dependency on the user-defined Ping
// service. Presumably it matches the full gRPC service name of
// protocol.PingUser -- confirm against the generated code.
const pingUserServiceName = "tast.coretest.PingUser"
// pingUserServer implements the user-defined Ping gRPC service by delegating
// every call to a callback supplied by the test.
type pingUserServer struct {
	// s is the service state received at registration time. It is handed to
	// onPing unchanged.
	s *testing.ServiceState
	// onPing is invoked each time a gRPC client calls Ping.
	onPing func(context.Context, *testing.ServiceState) error
}

// Ping runs the onPing callback. It replies with an empty message when the
// callback succeeds, and otherwise returns the callback's error as is.
func (p *pingUserServer) Ping(ctx context.Context, _ *empty.Empty) (*empty.Empty, error) {
	err := p.onPing(ctx, p.s)
	if err != nil {
		return nil, err
	}
	return new(empty.Empty), nil
}
// pingCoreServer is a stateless implementation of the core Ping gRPC service
// whose Ping method always succeeds.
type pingCoreServer struct{}

// Ping ignores its arguments and replies with an empty message.
func (*pingCoreServer) Ping(_ context.Context, _ *empty.Empty) (*empty.Empty, error) {
	reply := new(empty.Empty)
	return reply, nil
}
// pingPanicServer is an implementation of the core Ping gRPC service whose
// Ping method always panics. Tests use it to check that a panic message is
// reported to the caller.
type pingPanicServer struct{}

// Ping never returns normally; it panics with a fixed message.
func (*pingPanicServer) Ping(_ context.Context, _ *empty.Empty) (*empty.Empty, error) {
	panic("pingPanicServer.Ping was called")
}
// pingPair bundles the client side of a local client/server pair of the Ping
// gRPC services together with the means to shut both sides down.
type pingPair struct {
	// UserClient talks to the user-defined Ping service.
	UserClient protocol.PingUserClient
	// CoreClient talks to the core Ping service.
	CoreClient protocol.PingCoreClient

	// There is no field for the server: it is implicitly owned by the
	// background goroutine that calls RunServer.

	rpcClient  *GenericClient // underlying gRPC connection
	stopServer func() error   // stops the gRPC server and reports its result
}

// Close closes the gRPC connection and then stops the gRPC server. Both steps
// are always attempted; the error from closing the connection takes
// precedence over the one from stopping the server.
func (p *pingPair) Close() error {
	closeErr := p.rpcClient.Close()
	stopErr := p.stopServer()
	if closeErr != nil {
		return closeErr
	}
	return stopErr
}
// newPingService builds a service definition for the user-defined Ping
// service. The given onPing callback runs on the server whenever the Ping
// gRPC method is called.
func newPingService(onPing func(context.Context, *testing.ServiceState) error) *testing.Service {
	register := func(srv *grpc.Server, state *testing.ServiceState) {
		protocol.RegisterPingUserServer(srv, &pingUserServer{s: state, onPing: onPing})
	}
	return &testing.Service{Register: register}
}
// newPingPair starts a local client/server pair of the Ping gRPC services
// connected through in-memory pipes.
//
// The server runs on a background goroutine with pingSvc as its only
// user-defined service and pingCoreServer registered as a core service. If
// the client cannot be created, the server is stopped and the test is failed
// with t.Fatal. On success the returned pingPair should be closed with
// pingPair.Close after its use.
func newPingPair(ctx context.Context, t *gotesting.T, req *protocol.HandshakeRequest, pingSvc *testing.Service) *pingPair {
	t.Helper()

	// Two pipes form the bidirectional link: client -> server and
	// server -> client.
	serverIn, clientOut := io.Pipe()
	clientIn, serverOut := io.Pipe()

	// The channel is buffered so that the server goroutine can deliver its
	// result without waiting for a receiver.
	serverDone := make(chan error, 1)
	registerCore := func(srv *grpc.Server, _ *protocol.HandshakeRequest) error {
		protocol.RegisterPingCoreServer(srv, &pingCoreServer{})
		return nil
	}
	go func() {
		serverDone <- RunServer(serverIn, serverOut, []*testing.Service{pingSvc}, registerCore)
	}()

	stopServer := func() error {
		// Closing the client ends of the pipes lets the gRPC server close its
		// singleton gRPC connection, which in turn makes the server stop via
		// PipeListener.
		clientOut.Close()
		clientIn.Close()
		return <-serverDone
	}

	client, err := NewClient(ctx, clientIn, clientOut, req)
	if err != nil {
		// t.Fatal unwinds this goroutine, so the deferred call stops the
		// server after the failure is recorded. Its error is ignored because
		// the test has already failed.
		defer stopServer()
		t.Fatal("newClient failed: ", err)
	}

	return &pingPair{
		UserClient: protocol.NewPingUserClient(client.Conn()),
		CoreClient: protocol.NewPingCoreClient(client.Conn()),
		rpcClient:  client,
		stopServer: stopServer,
	}
}
// channelSink is a log sink that forwards every message to a channel so that
// tests can inspect what was logged.
type channelSink struct {
	ch chan<- string
}

// newChannelSink creates a channelSink together with the receiving end of the
// channel its messages are sent to.
func newChannelSink() (*channelSink, <-chan string) {
	// The buffer size is arbitrary; it only needs to be large enough that
	// unit tests do not hang when they leave some messages unread.
	const bufSize = 1000
	msgs := make(chan string, bufSize)
	sink := &channelSink{ch: msgs}
	return sink, msgs
}

// Log sends msg to the sink's channel. It blocks while the channel buffer is
// full.
func (s *channelSink) Log(msg string) {
	s.ch <- msg
}
// TestRPCSuccess checks that a Ping call declaring the Ping service in its
// ServiceDeps reaches the server-side callback and succeeds.
func TestRPCSuccess(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(context.Context, *testing.ServiceState) error {
		pinged = true
		return nil
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	_, err := pair.UserClient.Ping(callCtx, &empty.Empty{})
	if err != nil {
		t.Error("Ping failed: ", err)
	}
	if !pinged {
		t.Error("onPing not called")
	}
}
// TestRPCFailure checks that an error returned by the server-side callback
// makes the Ping call fail on the client.
func TestRPCFailure(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(context.Context, *testing.ServiceState) error {
		pinged = true
		return errors.New("failure")
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	_, err := pair.UserClient.Ping(callCtx, &empty.Empty{})
	if err == nil {
		t.Error("Ping unexpectedly succeeded")
	}
	if !pinged {
		t.Error("onPing not called")
	}
}
// TestRPCPanic checks that a panic in the server-side callback makes the Ping
// call fail with an error that carries the panic message.
func TestRPCPanic(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(context.Context, *testing.ServiceState) error {
		pinged = true
		panic("Ping was called")
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	_, err := pair.UserClient.Ping(callCtx, &empty.Empty{})
	switch {
	case err == nil:
		t.Error("Ping unexpectedly succeeded")
	case !strings.Contains(err.Error(), "panic: Ping was called"):
		t.Error("Ping error did not contain panic info: ", err)
	}
	if !pinged {
		t.Error("onPing not called")
	}
}
// TestRPCNotRequested checks that the user-defined Ping service cannot be
// called when the handshake request did not set NeedUserServices.
func TestRPCNotRequested(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(context.Context, *testing.ServiceState) error {
		pinged = true
		return nil
	})

	// The zero-value request does not ask for user-defined gRPC services.
	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	_, err := pair.UserClient.Ping(callCtx, &empty.Empty{})
	if err == nil {
		t.Error("Ping unexpectedly succeeded")
	}
	if pinged {
		t.Error("onPing unexpectedly called")
	}
}
// TestRPCNoCurrentEntity checks that a Ping call made with a context that
// carries no CurrentEntity fails without reaching the server-side callback.
func TestRPCNoCurrentEntity(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(context.Context, *testing.ServiceState) error {
		pinged = true
		return nil
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	// The call deliberately uses a bare background context.
	_, err := pair.UserClient.Ping(context.Background(), &empty.Empty{})
	if err == nil {
		t.Error("Ping unexpectedly succeeded for a context missing CurrentEntity")
	}
	if pinged {
		t.Error("onPing unexpectedly called")
	}
}
// TestRPCRejectUndeclaredServices checks that a Ping call fails when the
// calling entity's ServiceDeps names some other service but not the Ping
// service.
func TestRPCRejectUndeclaredServices(t *gotesting.T) {
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	noop := func(context.Context, *testing.ServiceState) error { return nil }
	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, newPingService(noop))
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{"foo.Bar"}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	_, err := pair.UserClient.Ping(callCtx, &empty.Empty{})
	if err == nil {
		t.Error("Ping unexpectedly succeeded despite undeclared service")
	}
}
// TestRPCForwardCurrentEntity checks that the software dependencies set on
// the caller's CurrentEntity are visible through the context that the
// server-side callback receives.
func TestRPCForwardCurrentEntity(t *gotesting.T) {
	wantDeps := []string{"chrome", "android_p"}
	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	var (
		pinged  bool
		gotDeps []string
		depsOK  bool
	)
	service := newPingService(func(ctx context.Context, _ *testing.ServiceState) error {
		pinged = true
		gotDeps, depsOK = testcontext.SoftwareDeps(ctx)
		return nil
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	// baseCtx carries an empty CurrentEntity with no ServiceDeps, so this
	// call is expected to fail.
	if _, err := pair.UserClient.Ping(baseCtx, &empty.Empty{}); err == nil {
		t.Error("Ping unexpectedly succeeded for a context without CurrentEntity")
	}

	entity := &testcontext.CurrentEntity{
		ServiceDeps:     []string{pingUserServiceName},
		HasSoftwareDeps: true,
		SoftwareDeps:    wantDeps,
	}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}

	switch {
	case !pinged:
		t.Error("onPing not called")
	case !depsOK:
		t.Error("SoftwareDeps unavailable")
	case !reflect.DeepEqual(gotDeps, wantDeps):
		t.Errorf("SoftwareDeps mismatch: got %v, want %v", gotDeps, wantDeps)
	}
}
// TestRPCForwardLogs checks that an info-level log emitted by the server-side
// callback is already available on the client's logger when the Ping call
// returns.
func TestRPCForwardLogs(t *gotesting.T) {
	const exp = "hello"

	sink, logs := newChannelSink()
	logger := logging.NewSinkLogger(logging.LevelDebug, false, sink)
	baseCtx := logging.AttachLogger(context.Background(), logger)
	baseCtx = testcontext.WithCurrentEntity(baseCtx, &testcontext.CurrentEntity{})

	pinged := false
	service := newPingService(func(ctx context.Context, _ *testing.ServiceState) error {
		pinged = true
		logging.Debug(ctx, "world") // not delivered
		logging.Info(ctx, exp)
		return nil
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}
	if !pinged {
		t.Error("onPing not called")
	}

	// A non-blocking receive: the log must already be there, not merely
	// arrive eventually.
	select {
	case got := <-logs:
		if got != exp {
			t.Errorf("Got log %q; want %q", got, exp)
		}
	default:
		t.Error("Logs unavailable immediately on RPC completion")
	}
}
// TestRPCForwardLogsAsyncStress is a regression test for b/207577797.
// It exercises the scenario where a remote server emits a log in parallel to
// finishing a remote method call and/or the RPC connection is closed.
//
// NOTE(review): every attempt runs on its own goroutine and calls
// newPingPair, which fails the test with t.Fatal when the client cannot be
// created. The testing package expects Fatal to be called from the goroutine
// running the test function -- confirm whether that path matters here.
func TestRPCForwardLogsAsyncStress(t *gotesting.T) {
	// n is number of attempts. n=1000 takes less than one second on modern
	// machines.
	const n = 1000

	var wg sync.WaitGroup
	attempt := func() {
		defer wg.Done()

		baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})
		service := newPingService(func(ctx context.Context, _ *testing.ServiceState) error {
			logging.Info(ctx, "hello")
			go logging.Info(ctx, "world") // emit asynchronously
			return nil
		})

		pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
		defer pair.Close()

		entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
		callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
		if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
			t.Error("Ping failed: ", err)
		}
	}

	wg.Add(n)
	for i := 0; i < n; i++ {
		go attempt()
	}
	wg.Wait()
}
// TestRPCForwardTiming checks that a timing stage recorded by the server-side
// callback shows up as the single child of the root of the client's timing
// log.
func TestRPCForwardTiming(t *gotesting.T) {
	const stageName = "hello"

	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})
	timingLog := timing.NewLog()
	baseCtx = timing.NewContext(baseCtx, timingLog)

	pinged := false
	service := newPingService(func(ctx context.Context, _ *testing.ServiceState) error {
		pinged = true
		_, stage := timing.Start(ctx, stageName)
		stage.End()
		return nil
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}
	if !pinged {
		t.Error("onPing not called")
	}

	children := timingLog.Root.Children
	if len(children) == 1 && children[0].Name == stageName {
		return
	}
	// Dump the whole log as JSON to make the mismatch easy to diagnose.
	dump, err := json.Marshal(timingLog)
	if err != nil {
		t.Fatal("Failed to marshal timing JSON: ", err)
	}
	t.Errorf("Unexpected timing log: got %s, want a single %q entry", string(dump), stageName)
}
// TestRPCPullOutDir checks that files written by the server-side callback to
// its own output directory end up in the output directory of the calling
// entity once the Ping call returns.
func TestRPCPullOutDir(t *gotesting.T) {
	outDir := testutil.TempDir(t)
	defer os.RemoveAll(outDir)

	want := map[string]string{
		"a.txt":     "abc",
		"dir/b.txt": "def",
	}

	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})

	service := newPingService(func(ctx context.Context, _ *testing.ServiceState) error {
		remoteDir, ok := testcontext.OutDir(ctx)
		switch {
		case !ok:
			return errors.New("OutDir unavailable")
		case remoteDir == outDir:
			return errors.Errorf("OutDir given to service must not be that on the host: %s", remoteDir)
		}

		info, err := os.Stat(remoteDir)
		if err != nil {
			return err
		}
		// The directory must have all permission bits and the sticky bit
		// set.
		const mask = os.ModePerm | os.ModeSticky
		if mode := info.Mode() & mask; mode != mask {
			return errors.Errorf("wrong directory permission: got %o, want %o", mode, mask)
		}
		return testutil.WriteFiles(remoteDir, want)
	})

	pair := newPingPair(baseCtx, t, &protocol.HandshakeRequest{NeedUserServices: true}, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{
		ServiceDeps: []string{pingUserServiceName},
		OutDir:      outDir,
	}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}

	got, err := testutil.ReadFiles(outDir)
	if err != nil {
		t.Fatal("Failed to read output dir: ", err)
	}
	if diff := cmp.Diff(got, want); diff != "" {
		t.Errorf("Directory contents mismatch (-got +want):\n%s", diff)
	}
}
// TestRPCSetVars checks that a runtime variable passed in the handshake
// request's BundleInitParams can be read through ServiceState.Var by a
// service that lists the variable in its Vars.
func TestRPCSetVars(t *gotesting.T) {
	const (
		key = "var1"
		exp = "value1"
	)

	baseCtx := testcontext.WithCurrentEntity(context.Background(), &testcontext.CurrentEntity{})
	req := &protocol.HandshakeRequest{
		NeedUserServices: true,
		BundleInitParams: &protocol.BundleInitParams{
			Vars: map[string]string{key: exp},
		},
	}

	var (
		pinged bool
		value  string
		ok     bool
	)
	service := newPingService(func(_ context.Context, s *testing.ServiceState) error {
		pinged = true
		value, ok = s.Var(key)
		return nil
	})
	// Set service vars in service definition.
	service.Vars = []string{key}

	pair := newPingPair(baseCtx, t, req, service)
	defer pair.Close()

	entity := &testcontext.CurrentEntity{ServiceDeps: []string{pingUserServiceName}}
	callCtx := testcontext.WithCurrentEntity(baseCtx, entity)
	if _, err := pair.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}
	if !pinged {
		t.Error("onPing not called")
	}
	if !ok || value != exp {
		t.Errorf("Runtime var not set for key %q: got ok %t and value %q, want %q", key, ok, value, exp)
	}
}
// TestRPCServiceScopedContext checks that a log emitted through the
// service-scoped context (ServiceState.ServiceContext) rather than through
// the per-call context reaches the client's logger.
func TestRPCServiceScopedContext(t *gotesting.T) {
	const exp = "hello"
	// logTimeout bounds how long the test waits for the forwarded log, so
	// that a lost message produces a clear failure instead of a hang that
	// only ends when the whole test binary times out.
	const logTimeout = 10 * time.Second

	ctx := context.Background()
	sink, logs := newChannelSink()
	ctx = logging.AttachLogger(ctx, logging.NewSinkLogger(logging.LevelDebug, false, sink))
	ctx = testcontext.WithCurrentEntity(ctx, &testcontext.CurrentEntity{})
	req := &protocol.HandshakeRequest{NeedUserServices: true}

	called := false
	svc := newPingService(func(ctx context.Context, s *testing.ServiceState) error {
		called = true
		logging.Debug(ctx, "world") // not delivered
		logging.Info(s.ServiceContext(), exp)
		return nil
	})

	pp := newPingPair(ctx, t, req, svc)
	defer pp.Close()

	callCtx := testcontext.WithCurrentEntity(ctx, &testcontext.CurrentEntity{
		ServiceDeps: []string{pingUserServiceName},
	})
	if _, err := pp.UserClient.Ping(callCtx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}
	if !called {
		t.Error("onPing not called")
	}

	// Unlike TestRPCForwardLogs, this test does not require the log to be
	// available immediately on RPC completion, so wait for it -- but only
	// for a bounded time.
	select {
	case msg := <-logs:
		if msg != exp {
			t.Errorf("Got log %q; want %q", msg, exp)
		}
	case <-time.After(logTimeout):
		t.Errorf("Timed out after %v waiting for log %q", logTimeout, exp)
	}
}
func TestRPCExtraCoreServices(t *gotesting.T) {
ctx := context.Background()
req := &protocol.HandshakeRequest{NeedUserServices: false}
svc := newPingService(nil)
pp := newPingPair(ctx, t, req, svc)
defer pp.Close()
if _, err := pp.CoreClient.Ping(ctx, &empty.Empty{}); err != nil {
t.Error("Ping failed: ", err)
}
}
func TestRPCOverExec(t *gotesting.T) {
ctx := context.Background()
// Create a loopback executable providing gRPC server.
dir, err := ioutil.TempDir("", "tast-unittest.")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
path := filepath.Join(dir, "rpc-server")
lo, err := fakeexec.CreateLoopback(path, func(_ []string, stdin io.Reader, stdout, _ io.WriteCloser) int {
if err := RunServer(stdin, stdout, nil, func(srv *grpc.Server, req *protocol.HandshakeRequest) error {
protocol.RegisterPingCoreServer(srv, &pingCoreServer{})
return nil
}); err != nil {
fmt.Fprintf(os.Stderr, "FATAL: %v\n", err)
return 1
}
return 0
})
if err != nil {
t.Fatal(err)
}
defer lo.Close()
// Connect to the server and try calling a method.
conn, err := DialExec(ctx, path, false, &protocol.HandshakeRequest{})
if err != nil {
t.Fatalf("DialExec failed: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
}()
cl := protocol.NewPingCoreClient(conn.Conn())
if _, err := cl.Ping(ctx, &empty.Empty{}); err != nil {
t.Error("Ping failed: ", err)
}
}
func TestPanicOverExec(t *gotesting.T) {
ctx := context.Background()
// Create a loopback executable providing gRPC server.
dir, err := ioutil.TempDir("", "tast-unittest.")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
path := filepath.Join(dir, "rpc-server")
lo, err := fakeexec.CreateLoopback(path, func(_ []string, stdin io.Reader, stdout, _ io.WriteCloser) int {
if err := RunServer(stdin, stdout, nil, func(srv *grpc.Server, req *protocol.HandshakeRequest) error {
protocol.RegisterPingCoreServer(srv, &pingPanicServer{})
return nil
}); err != nil {
fmt.Fprintf(os.Stderr, "FATAL: %v\n", err)
return 1
}
return 0
})
if err != nil {
t.Fatal(err)
}
defer lo.Close()
// Connect to the server and try calling a method.
conn, err := DialExec(ctx, path, false, &protocol.HandshakeRequest{})
if err != nil {
t.Fatalf("DialExec failed: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
}()
cl := protocol.NewPingCoreClient(conn.Conn())
if _, err := cl.Ping(ctx, &empty.Empty{}); err == nil {
t.Error("Ping unexpectedly succeeded")
} else if !strings.Contains(err.Error(), "panic: pingPanicServer.Ping was called") {
t.Error("Ping error did not contain panic info: ", err)
}
}
// leakingPingServer is an implementation of the core Ping gRPC service that
// leaves a running subprocess behind on every call.
type leakingPingServer struct{}

// Ping starts a "sleep 60" subprocess without ever waiting for it and then
// replies with an empty message. It returns an error if the subprocess cannot
// be started, so that the caller sees the real cause instead of later failing
// to find the leaked subprocess.
func (s *leakingPingServer) Ping(ctx context.Context, _ *empty.Empty) (*empty.Empty, error) {
	// Intentionally leak a subprocess.
	if err := exec.Command("sleep", "60").Start(); err != nil {
		return nil, errors.Wrap(err, "failed to start a subprocess to leak")
	}
	return &empty.Empty{}, nil
}
// leakingMain is an auxiliary main function that serves leakingPingServer
// over the process's stdin/stdout. If the server returns an error, it is
// printed to stderr and the process exits with status 1.
var leakingMain = fakeexec.NewAuxMain("rpc_new_session_test", func(_ struct{}) {
	register := func(srv *grpc.Server, _ *protocol.HandshakeRequest) error {
		protocol.RegisterPingCoreServer(srv, &leakingPingServer{})
		return nil
	}
	err := RunServer(os.Stdin, os.Stdout, nil, register)
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "FATAL: %v\n", err)
	os.Exit(1)
})
// TestRPCOverExecNewSession checks the effect of DialExec's newSession
// argument on subprocesses leaked by the server. With newSession=true the
// leaked subprocess is expected to be gone shortly after the connection is
// closed; with newSession=false it is expected to still be running, and the
// test terminates it itself.
func TestRPCOverExecNewSession(t *gotesting.T) {
	ctx := context.Background()

	params, err := leakingMain.Params(struct{}{})
	if err != nil {
		t.Fatal(err)
	}
	restoreEnvs := params.SetEnvs()
	defer restoreEnvs()

	for _, newSession := range []bool{false, true} {
		t.Run(strconv.FormatBool(newSession), func(t *gotesting.T) {
			// Run the RPC inside a function literal so that the connection
			// is already closed by the time the leaked subprocess is
			// inspected below.
			leaked := func() *process.Process {
				// Connect to the server and call a method.
				conn, err := DialExec(ctx, params.Executable(), newSession, &protocol.HandshakeRequest{})
				if err != nil {
					t.Fatalf("DialExec failed: %v", err)
				}
				defer func() {
					if err := conn.Close(); err != nil {
						t.Errorf("Close failed: %v", err)
					}
				}()

				// Call Ping. This will leak a subprocess.
				client := protocol.NewPingCoreClient(conn.Conn())
				if _, err := client.Ping(ctx, &empty.Empty{}); err != nil {
					t.Error("Ping failed: ", err)
				}

				// Find the leaked subprocess: the first process whose parent
				// is the server process.
				procs, err := process.Processes()
				if err != nil {
					t.Fatalf("Failed to enumerate processes: %v", err)
				}
				for _, proc := range procs {
					if ppid, err := proc.Ppid(); err == nil && int(ppid) == conn.PID() {
						return proc
					}
				}
				t.Fatal("Failed to find a leaked subprocess")
				return nil // not reached; t.Fatal does not return
			}()

			if !newSession {
				// Leaked subprocess should be still running.
				if err := leaked.Terminate(); err != nil {
					t.Fatalf("Failed to kill the leaked subprocess: %v", err)
				}
				return
			}

			// Closing the connection should have killed the whole session.
			// Wait some time to allow the process to exit.
			checkGone := func(context.Context) error {
				if _, err := leaked.Status(); err != nil {
					return nil
				}
				return errors.Errorf("process %d still exists", leaked.Pid)
			}
			opts := &testingutil.PollOptions{Timeout: 10 * time.Second}
			if err := testingutil.Poll(context.Background(), checkGone, opts); err != nil {
				t.Fatalf("Failed to wait for a leaked subprocess to exit: %v", err)
			}
		})
	}
}
func TestRPCOverSSH(t *gotesting.T) {
ctx := context.Background()
// Start a fake SSH server providing gRPC server.
td := sshtest.NewTestData(func(req *sshtest.ExecReq) {
req.Start(true)
if err := RunServer(req, req, nil, func(srv *grpc.Server, req *protocol.HandshakeRequest) error {
protocol.RegisterPingCoreServer(srv, &pingCoreServer{})
return nil
}); err != nil {
fmt.Fprintf(req.Stderr(), "FATAL: %v\n", err)
req.End(1)
return
}
req.End(0)
})
defer td.Close()
sshConn, err := ssh.New(ctx, &ssh.Options{
Hostname: td.Srvs[0].Addr().String(),
KeyFile: td.UserKeyFile,
})
if err != nil {
t.Fatalf("Failed to connect to fake SSH server: %v", err)
}
defer sshConn.Close(ctx)
// Connect to the server and try calling a method.
conn, err := DialSSH(ctx, sshConn, "", &protocol.HandshakeRequest{}, false)
if err != nil {
t.Fatalf("DialSSH failed: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
}()
cl := protocol.NewPingCoreClient(conn.Conn())
if _, err := cl.Ping(ctx, &empty.Empty{}); err != nil {
t.Error("Ping failed: ", err)
}
}
// TestRPCOverSSHShortContext exercises a DialSSH connection whose dial
// context is canceled while the connection is still in use. As written, it
// expects Ping to keep succeeding after the cancellation, Close to fail with
// a "remote logging background routine failed" / "context canceled" error,
// and the underlying SSH connection to stay usable.
//
// The test is currently skipped because of flakiness (b/262413993).
func TestRPCOverSSHShortContext(t *gotesting.T) {
	t.Skip("Test disabled because of flakiness b/262413993")

	ctx := logging.AttachLoggerNoPropagation(context.Background(), loggingtest.NewLogger(t, logging.LevelInfo))

	// Start a fake SSH server providing gRPC server.
	registerPing := func(srv *grpc.Server, _ *protocol.HandshakeRequest) error {
		protocol.RegisterPingCoreServer(srv, &pingCoreServer{})
		return nil
	}
	td := sshtest.NewTestData(func(req *sshtest.ExecReq) {
		req.Start(true)
		err := RunServer(req, req, nil, registerPing)
		if err != nil {
			fmt.Fprintf(req.Stderr(), "FATAL: %v\n", err)
			req.End(1)
			return
		}
		req.End(0)
	})
	defer td.Close()

	sshOpts := &ssh.Options{
		Hostname: td.Srvs[0].Addr().String(),
		KeyFile:  td.UserKeyFile,
	}
	sshConn, err := ssh.New(ctx, sshOpts)
	if err != nil {
		t.Fatalf("Failed to connect to fake SSH server: %v", err)
	}
	defer sshConn.Close(ctx)

	// Connect to the server with a context that is canceled part way
	// through the test.
	shortContext, shortCancel := context.WithCancel(ctx)
	conn, err := DialSSH(shortContext, sshConn, "", &protocol.HandshakeRequest{}, false)
	if err != nil {
		t.Fatalf("DialSSH failed: %v", err)
	}
	client := protocol.NewPingCoreClient(conn.Conn())

	// Should succeed: the dial context is still alive.
	t.Log("Ping")
	if _, err := client.Ping(ctx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}

	t.Log("Cancel shortContext")
	shortCancel()

	// Expected to succeed as well, even though the dial context is gone.
	t.Log("Ping")
	if _, err := client.Ping(ctx, &empty.Empty{}); err != nil {
		t.Error("Ping failed: ", err)
	}

	// Should fail
	t.Log("Close conn")
	closeErr := conn.Close()
	switch {
	case closeErr == nil:
		t.Error("Close succeeded, but was expected to fail")
	case !strings.Contains(closeErr.Error(), "remote logging background routine failed") ||
		!strings.Contains(closeErr.Error(), "context canceled"):
		t.Error("Close failed with wrong error: ", closeErr)
	}

	if err := sshConn.Ping(ctx, time.Second); err != nil {
		t.Error("SSH ping failed: ", err)
	}
}
func TestPanicOverSSH(t *gotesting.T) {
ctx := context.Background()
// Start a fake SSH server providing gRPC server.
td := sshtest.NewTestData(func(req *sshtest.ExecReq) {
req.Start(true)
if err := RunServer(req, req, nil, func(srv *grpc.Server, req *protocol.HandshakeRequest) error {
protocol.RegisterPingCoreServer(srv, &pingPanicServer{})
return nil
}); err != nil {
fmt.Fprintf(req.Stderr(), "FATAL: %v\n", err)
req.End(1)
return
}
req.End(0)
})
defer td.Close()
sshConn, err := ssh.New(ctx, &ssh.Options{
Hostname: td.Srvs[0].Addr().String(),
KeyFile: td.UserKeyFile,
})
if err != nil {
t.Fatalf("Failed to connect to fake SSH server: %v", err)
}
defer sshConn.Close(ctx)
// Connect to the server and try calling a method.
conn, err := DialSSH(ctx, sshConn, "", &protocol.HandshakeRequest{}, false)
if err != nil {
t.Fatalf("DialSSH failed: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
}()
cl := protocol.NewPingCoreClient(conn.Conn())
if _, err := cl.Ping(ctx, &empty.Empty{}); err == nil {
t.Error("Ping unexpectedly succeeded")
} else if !strings.Contains(err.Error(), "panic: pingPanicServer.Ping was called") {
t.Error("Ping error did not contain panic info: ", err)
}
}
// Contents that subprocessServer.Ping writes to its status file to tell the
// parent test process how far the method call has progressed.
const (
// textReady is written once Ping has been entered.
textReady = "ready"
// textFinished is written after the context given to Ping is done.
textFinished = "finished"
)
// subprocessServer is an implementation of the core Ping gRPC service that
// reports the progress of a method call by overwriting a file.
type subprocessServer struct {
	// path is the file that progress markers are written to.
	path string
}

// Ping writes textReady to the status file, blocks until ctx is done, and
// then writes textFinished before replying with an empty message. It fails
// right away if ctx is already done on entry. Errors from writing the file
// are deliberately ignored.
func (srv *subprocessServer) Ping(ctx context.Context, _ *empty.Empty) (*empty.Empty, error) {
	if err := ctx.Err(); err != nil {
		return nil, errors.Wrap(err, "context already canceled on entering method")
	}

	// report overwrites the status file with the given marker.
	report := func(marker string) {
		_ = ioutil.WriteFile(srv.path, []byte(marker), 0666)
	}

	// Notify the parent process that we're in the middle of a method call.
	report(textReady)

	// Wait for the context to be canceled.
	<-ctx.Done()

	// Notify the parent process that we're finishing the method call.
	report(textFinished)

	return new(empty.Empty), nil
}
// stdioMain is an auxiliary main function that serves subprocessServer over
// the process's stdin/stdout. Its argument is the path of the status file
// that subprocessServer writes to. The result of RunServer is ignored.
var stdioMain = fakeexec.NewAuxMain("rpc_stdio_test", func(path string) {
	register := func(srv *grpc.Server, _ *protocol.HandshakeRequest) error {
		protocol.RegisterPingCoreServer(srv, &subprocessServer{path})
		return nil
	}
	_ = RunServer(os.Stdin, os.Stdout, nil, register)
})
// runStdioTestServer starts a subprocess serving subprocessServer and
// starts an asynchronous call of its Ping method.
//
// It returns the subprocess command, the pipes connected to the subprocess's
// stdin and stdout, and two functions that block until the subprocess has
// entered (waitReady) and finished (waitFinish) the Ping method. Each of them
// fails the given test if the expected state is not reached within 10
// seconds. Cleanup of the subprocess, the connection and the temporary file
// is registered on t.
func runStdioTestServer(t *gotesting.T) (cmd *exec.Cmd, stdin io.WriteCloser, stdout io.ReadCloser, waitReady, waitFinish func(t *gotesting.T)) {
	ctx := context.Background()

	// Create a temporary file. It is initially empty, but the subprocess
	// writes some data to it later.
	statusFile, err := ioutil.TempFile("", "tast-unittest.")
	if err != nil {
		t.Fatal(err)
	}
	statusPath := statusFile.Name()
	t.Cleanup(func() {
		statusFile.Close()
		os.Remove(statusPath)
	})

	// Run a fake subprocess serving subprocessServer.
	params, err := stdioMain.Params(statusPath)
	if err != nil {
		t.Fatal(err)
	}
	cmd = exec.Command(params.Executable())
	cmd.Env = append(os.Environ(), params.Envs()...)
	cmd.Stderr = os.Stderr
	if stdin, err = cmd.StdinPipe(); err != nil {
		t.Fatal(err)
	}
	if stdout, err = cmd.StdoutPipe(); err != nil {
		t.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		cmd.Process.Kill()
		cmd.Wait()
	})

	conn, err := NewClient(ctx, stdout, stdin, &protocol.HandshakeRequest{})
	if err != nil {
		t.Fatalf("Failed to establish gRPC connection to subprocess: %v", err)
	}
	t.Cleanup(func() {
		conn.Close()
	})

	// Make an RPC call on a goroutine. Its result is not inspected; progress
	// is observed through the status file instead.
	go func() {
		protocol.NewPingCoreClient(conn.Conn()).Ping(ctx, &empty.Empty{})
	}()

	// waitFor returns a function that waits until the content of the status
	// file becomes want.
	waitFor := func(want string) func(*gotesting.T) {
		return func(t *gotesting.T) {
			check := func(context.Context) error {
				b, err := ioutil.ReadFile(statusPath)
				if err != nil {
					return testingutil.PollBreak(err)
				}
				if got := string(b); got != want {
					return errors.Errorf("content mismatch: got %q, want %q", got, want)
				}
				return nil
			}
			opts := &testingutil.PollOptions{Timeout: 10 * time.Second}
			if err := testingutil.Poll(ctx, check, opts); err != nil {
				t.Fatalf("Failed to wait for subprocess write: %v", err)
			}
		}
	}

	// waitReady waits for the subprocess to enter the gRPC method.
	waitReady = waitFor(textReady)
	// waitFinish waits for the subprocess to finish the gRPC method call.
	waitFinish = waitFor(textFinished)
	return cmd, stdin, stdout, waitReady, waitFinish
}
// TestRPCOverStdioSIGPIPE checks that the subprocess still finishes its
// in-flight Ping call after both of its stdio pipes have been closed by the
// parent.
func TestRPCOverStdioSIGPIPE(t *gotesting.T) {
	_, toServer, fromServer, waitReady, waitFinish := runStdioTestServer(t)
	waitReady(t)

	// Close stdout of the subprocess. If the subprocess doesn't install
	// SIGPIPE handlers, writing data to stdout will cause termination.
	fromServer.Close()

	// Close stdin to stop the gRPC server.
	toServer.Close()

	waitFinish(t)
}
// TestRPCOverStdioSIGINT checks that the subprocess still finishes its
// in-flight Ping call after it has been sent SIGINT.
func TestRPCOverStdioSIGINT(t *gotesting.T) {
	cmd, _, _, waitReady, waitSuccess := runStdioTestServer(t)
	waitReady(t)

	// Send SIGINT to the subprocess. If the subprocess doesn't install
	// SIGINT handlers it will terminate immediately.
	//
	// Fail right away if the signal cannot be delivered; otherwise the test
	// would only fail later with an unrelated poll timeout.
	if err := cmd.Process.Signal(unix.SIGINT); err != nil {
		t.Fatalf("Failed to send SIGINT to subprocess: %v", err)
	}

	waitSuccess(t)
}