tast: Enforce timeouts when reading control messages.

Make the tast executable enforce timeouts when reading
control messages sent by the local_tests executable (or by
remote_tests, although that runs on the same machine).

By default, a one-minute timeout is used, but this is
lengthened based on per-test timeouts while tests are
running.

BUG=chromium:778389
TEST=added unit tests; also pulled the plug on a device
     mid-test and checked that the tast command timed out
     after two minutes

Change-Id: Ia60c06d408692e3d5f9bfabec7059002d581fdfd
Reviewed-on: https://chromium-review.googlesource.com/757163
Commit-Ready: Dan Erat <derat@chromium.org>
Tested-by: Dan Erat <derat@chromium.org>
Reviewed-by: Jason Clinton <jclinton@chromium.org>
diff --git a/src/chromiumos/tast/cmd/run/config.go b/src/chromiumos/tast/cmd/run/config.go
index f3524ae..6bee6bb 100644
--- a/src/chromiumos/tast/cmd/run/config.go
+++ b/src/chromiumos/tast/cmd/run/config.go
@@ -8,6 +8,7 @@
 import (
 	"flag"
 	"path/filepath"
+	"time"
 
 	"chromiumos/tast/cmd/build"
 	"chromiumos/tast/cmd/logging"
@@ -26,6 +27,8 @@
 	Target   string         // target for testing, in the form "[<user>@]host[:<port>]"
 	Patterns []string       // patterns specifying tests to run
 	ResDir   string         // directory where test results should be written
+
+	msgTimeout time.Duration // timeout for reading control messages; default used if zero
 }
 
 // SetFlags adds common run-related flags to f that store values in Config.
diff --git a/src/chromiumos/tast/cmd/run/results.go b/src/chromiumos/tast/cmd/run/results.go
index b82864a..0586102 100644
--- a/src/chromiumos/tast/cmd/run/results.go
+++ b/src/chromiumos/tast/cmd/run/results.go
@@ -30,6 +30,8 @@
 	testLogFilename = "log.txt"      // file in test's dir containing its logs
 
 	testOutputTimeFmt = "15:04:05.000" // format for timestamps attached to test output
+
+	defaultMsgTimeout = time.Minute // default timeout for reading next control message
 )
 
 // testResult contains the results from a single test.
@@ -41,14 +43,15 @@
 	// Errors contains errors encountered while running the test.
 	// If it is empty, the test passed.
 	Errors []testing.Error `json:"errors"`
-	// Start is the time at which the test started.
+	// Start is the time at which the test started (as reported by the test binary).
 	Start time.Time `json:"start"`
-	// End is the time at which the test completed.
+	// End is the time at which the test completed (as reported by the test binary).
 	End time.Time `json:"end"`
 	// OutDir is the directory into which test output is stored.
 	OutDir string `json:"outDir"`
 
-	logFile *os.File // test's log file
+	testStartMsgTime time.Time // time at which TestStart control message was received
+	logFile          *os.File  // test's log file
 }
 
 // copyAndRemoveFunc copies src on a DUT to dst on the local machine and then
@@ -141,9 +144,10 @@
 	}
 
 	r.res = &testResult{
-		Test:   msg.Test,
-		Start:  msg.Time,
-		OutDir: r.getTestOutputDir(msg.Test.Name),
+		Test:             msg.Test,
+		Start:            msg.Time,
+		OutDir:           r.getTestOutputDir(msg.Test.Name),
+		testStartMsgTime: time.Now(),
 	}
 
 	var err error
@@ -299,6 +303,36 @@
 	return nil
 }
 
+// nextMessageTimeout returns how long, measured from now, to wait for the next
+// control message from the test executable. The base is cfg.msgTimeout, or
+// defaultMsgTimeout if that is not positive, plus any unexpired part of the running
+// test's timeout; the total is capped by the time left before the context's deadline.
+func (r *resultsHandler) nextMessageTimeout(now time.Time) time.Duration {
+	wait := r.cfg.msgTimeout
+	if wait <= 0 {
+		wait = defaultMsgTimeout
+	}
+
+	// Mid-test, also allow for whatever is left of the test's own timeout.
+	if r.res != nil {
+		if elapsed := now.Sub(r.res.testStartMsgTime); elapsed < r.res.Timeout {
+			wait += r.res.Timeout - elapsed
+		}
+	}
+
+	// Never wait past the context's deadline, if it has one.
+	deadline, hasDeadline := r.ctx.Deadline()
+	switch {
+	case !hasDeadline:
+		return wait
+	case now.After(deadline):
+		return 0
+	case deadline.Sub(now) < wait:
+		return deadline.Sub(now)
+	}
+	return wait
+}
+
 // handleMessage handles generic control messages from test executables.
 func (r *resultsHandler) handleMessage(msg interface{}) error {
 	switch v := msg.(type) {
@@ -323,6 +357,44 @@
 	}
 }
 
+// processMessages hands each message received on mch to handleMessage. It returns
+// when mch is closed, ech yields a value, handling fails, or the wait times out.
+func (r *resultsHandler) processMessages(mch chan interface{}, ech chan error) error {
+	for {
+		wait := r.nextMessageTimeout(time.Now())
+		select {
+		case readErr := <-ech:
+			return readErr
+		case m := <-mch:
+			if m == nil { // a closed channel yields the zero value
+				return nil
+			}
+			if err := r.handleMessage(m); err != nil {
+				return err
+			}
+		case <-time.After(wait):
+			return fmt.Errorf("timed out after waiting %v for next message", wait)
+		}
+	}
+}
+
+// readMessages pulls control messages out of r one at a time and sends each on
+// mch. The first error returned by ReadMessage is sent on ech instead, which ends
+// the reading. Both channels are closed before readMessages returns.
+func readMessages(r io.Reader, mch chan interface{}, ech chan error) {
+	// Deferred calls run in reverse order, so mch is closed before ech.
+	defer close(ech)
+	defer close(mch)
+	for dec := control.NewMessageReader(r); dec.More(); {
+		m, readErr := dec.ReadMessage()
+		if readErr != nil {
+			ech <- readErr
+			return
+		}
+		mch <- m
+	}
+}
+
 // readTestOutput reads test output from r and writes the test results to cfg.ResDir.
 func readTestOutput(ctx context.Context, cfg *Config, r io.Reader, crf copyAndRemoveFunc) error {
 	rh := resultsHandler{
@@ -333,15 +405,15 @@
 	}
 	defer rh.close()
 
-	mr := control.NewMessageReader(r)
-	for mr.More() {
-		msg, err := mr.ReadMessage()
-		if err != nil {
-			return err
-		}
-		if err = rh.handleMessage(msg); err != nil {
-			return err
-		}
+	mch := make(chan interface{})
+	ech := make(chan error)
+	go readMessages(r, mch, ech)
+
+	if err := rh.processMessages(mch, ech); err != nil {
+		return err
 	}
+
+	// TODO(derat): Check that RunStart and RunEnd messages were received and that the
+	// number of TestStart/TestEnd pairs matched the number specified in RunStart.
 	return nil
 }
diff --git a/src/chromiumos/tast/cmd/run/results_test.go b/src/chromiumos/tast/cmd/run/results_test.go
index 43f13b7..0bdd185 100644
--- a/src/chromiumos/tast/cmd/run/results_test.go
+++ b/src/chromiumos/tast/cmd/run/results_test.go
@@ -8,6 +8,7 @@
 	"bytes"
 	"context"
 	"encoding/json"
+	"io"
 	"os"
 	"path/filepath"
 	gotesting "testing"
@@ -19,6 +20,9 @@
 	"chromiumos/tast/common/testutil"
 )
 
+// noOpCopyAndRemove ignores its arguments and returns nil; tests pass it to readTestOutput.
+func noOpCopyAndRemove(src, dst string) error { return nil }
+
 func TestReadTestOutput(t *gotesting.T) {
 	const (
 		test1Name    = "foo.FirstTest"
@@ -118,3 +122,89 @@
 
 	// TODO(derat): Check more output, including run errors.
 }
+
+// TestReadTestOutputTimeout checks that readTestOutput errors out when no message arrives.
+func TestReadTestOutputTimeout(t *gotesting.T) {
+	resDir := testutil.TempDir(t, "results_test.")
+	defer os.RemoveAll(resDir)
+
+	// Nothing is written to the pipe, and it stays open until the test returns.
+	reader, writer := io.Pipe()
+	defer writer.Close()
+
+	cfg := &Config{
+		Logger:     logging.NewSimple(&bytes.Buffer{}, 0, false),
+		ResDir:     resDir,
+		msgTimeout: time.Millisecond,
+	}
+	if err := readTestOutput(context.Background(), cfg, reader, noOpCopyAndRemove); err == nil {
+		t.Error("readTestOutput didn't return error for timeout")
+	}
+}
+
+// TestNextMessageTimeout checks nextMessageTimeout against a table of timeout settings.
+func TestNextMessageTimeout(t *gotesting.T) {
+	now := time.Unix(60, 0)
+
+	for _, tc := range []struct {
+		msgTimeout  time.Duration
+		ctxTimeout  time.Duration
+		testStart   time.Time
+		testTimeout time.Duration
+		exp         time.Duration
+	}{
+		{
+			// Outside a test, and without a custom or context timeout, use the default.
+			exp: defaultMsgTimeout,
+		},
+		{
+			// If a message timeout is supplied, use it instead of default.
+			msgTimeout: 5 * time.Second,
+			exp:        5 * time.Second,
+		},
+		{
+			// Mid-test, use the test's remaining time plus the normal message timeout.
+			msgTimeout:  10 * time.Second,
+			testStart:   now.Add(-1 * time.Second),
+			testTimeout: 5 * time.Second,
+			exp:         14 * time.Second,
+		},
+		{
+			// A context timeout should cap whatever timeout would be used otherwise.
+			msgTimeout: 20 * time.Second,
+			ctxTimeout: 11 * time.Second,
+			exp:        11 * time.Second,
+		},
+	} {
+		ctx := context.Background()
+		var cancel context.CancelFunc
+		if tc.ctxTimeout != 0 {
+			ctx, cancel = context.WithDeadline(ctx, now.Add(tc.ctxTimeout))
+		}
+
+		h := resultsHandler{
+			ctx: ctx,
+			cfg: &Config{msgTimeout: tc.msgTimeout},
+		}
+		if !tc.testStart.IsZero() {
+			h.res = &testResult{
+				Test:             testing.Test{Timeout: tc.testTimeout},
+				testStartMsgTime: tc.testStart,
+			}
+		}
+
+		// Avoid printing ugly negative numbers for unset testStart fields.
+		var testStartUnix int64
+		if !tc.testStart.IsZero() {
+			testStartUnix = tc.testStart.Unix()
+		}
+		if act := h.nextMessageTimeout(now); act != tc.exp {
+			t.Errorf("nextMessageTimeout(%v) (msgTimeout=%v, ctxTimeout=%v testStart=%v, testTimeout=%v) = %v; want %v",
+				now.Unix(), tc.msgTimeout, tc.ctxTimeout, testStartUnix, tc.testTimeout, act, tc.exp)
+		}
+
+		if cancel != nil {
+			cancel()
+		}
+	}
+}