blob: e2769b33c5523e5e89a550dfd677cac04b3a361d [file] [log] [blame] [edit]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshtest
import (
"context"
"crypto/rsa"
"fmt"
"os"
"strings"
"testing"
"time"
"go.chromium.org/infra/cros/servo/shutil"
"go.chromium.org/infra/cros/servo/ssh"
)
// ConnectToServer establishes a connection to srv using key.
//
// key is written to a file with WriteKey so that it can be passed to the
// client through Options.KeyFile; that file is removed again before
// ConnectToServer returns.
//
// base is used as a base set of options. It is copied rather than modified:
// the copy gets KeyFile set and is then handed to ssh.ParseTarget together
// with srv's address. A nil base is treated as an empty set of options.
//
// On success the new connection is returned; otherwise the first error
// encountered is returned.
func ConnectToServer(ctx context.Context, srv *SSHServer, key *rsa.PrivateKey, base *ssh.Options) (*ssh.Conn, error) {
	keyFile, err := WriteKey(key)
	if err != nil {
		return nil, err
	}
	// The key file is only needed while the connection is being set up.
	defer os.Remove(keyFile)

	// Work on a copy so the caller's options are left untouched, and tolerate
	// a nil base instead of panicking on the dereference.
	var o ssh.Options
	if base != nil {
		o = *base
	}
	o.KeyFile = keyFile
	if err := ssh.ParseTarget(srv.Addr().String(), &o); err != nil {
		return nil, err
	}
	return ssh.New(ctx, &o)
}
// TimeoutType selects which kind of timeout, if any, is simulated while an
// SSH "exec" request is being handled.
type TimeoutType int

const (
	// NoTimeout means the test context is left alone (no simulated timeout).
	// It is the zero value of TimeoutType.
	NoTimeout TimeoutType = iota

	// StartTimeout means the test context is canceled before the command
	// starts.
	StartTimeout

	// EndTimeout means the test context is canceled after the command has run
	// but before its status is returned.
	EndTimeout
)
// TestDataConn wraps data common to all tests.
// Whereas TestData only manages an SSHServer, TestDataConn additionally owns
// a connection to that server.
type TestDataConn struct {
	// Srv is the local SSH server.
	Srv *SSHServer

	// Hst is a connection to Srv.
	Hst *ssh.Conn

	// Ctx is used for performing operations using Hst.
	Ctx context.Context

	// Cancel cancels Ctx to simulate a timeout.
	Cancel func()

	// ExecTimeout directs how "exec" requests should time out.
	ExecTimeout TimeoutType
}
// NewTestDataConn sets up a local SSH server and a connection to it, and
// returns them together as a TestDataConn struct.
//
// The returned value carries a cancelable context (Ctx/Cancel) derived from
// context.Background; the server is created with td.handleExec as its "exec"
// handler, and the connection is made with freshly generated keys and empty
// base options.
//
// Any setup failure is reported through t.Fatal after releasing whatever was
// already created. Caller must call Close after use.
func NewTestDataConn(t *testing.T) *TestDataConn {
	// Attribute setup failures to the calling test rather than to this helper.
	t.Helper()

	td := &TestDataConn{}
	td.Ctx, td.Cancel = context.WithCancel(context.Background())
	userKey, hostKey := MustGenerateKeys()

	var err error
	if td.Srv, err = NewSSHServer(&userKey.PublicKey, hostKey, td.handleExec); err != nil {
		// Nothing else exists yet; just release the context.
		td.Cancel()
		t.Fatal(err)
	}
	if td.Hst, err = ConnectToServer(td.Ctx, td.Srv, userKey, &ssh.Options{}); err != nil {
		// Best-effort cleanup: the connection error is the one worth reporting.
		_ = td.Srv.Close()
		td.Cancel()
		t.Fatal(err)
	}
	return td
}
// Close releases resources associated with td: it closes the server, then the
// connection (using td.Ctx), and finally cancels td.Ctx.
// Errors from closing are deliberately discarded; teardown is best-effort.
func (td *TestDataConn) Close() {
	server, conn := td.Srv, td.Hst

	_ = server.Close()
	_ = conn.Close(td.Ctx)

	// Cancel last so that the connection is closed with the context as it was.
	td.Cancel()
}
// handleExec handles an SSH "exec" request sent to td.Srv by executing the
// requested command, optionally simulating a timeout as directed by
// td.ExecTimeout.
//
// The special command produced by shellCmd("", []string{"long_sleep"}) is not
// executed; the handler sleeps for an hour instead. Every other command is run
// through req.RunRealCmd and its status is passed to req.End.
func (td *TestDataConn) handleExec(req *ExecReq) {
	// PutFiles sends multiple "exec" requests.
	// Its initial "sha1sum" is exempt from simulated timeouts so that the hang
	// happens during the tar command instead.
	honorTimeout := !strings.HasPrefix(req.Cmd, "sha1sum ")

	// simulate cancels the context if a timeout was requested for the given
	// phase, then sleeps for an arbitrary-but-long amount of time to make sure
	// that the client sees the expired context before the handler proceeds.
	simulate := func(phase TimeoutType) {
		if !honorTimeout || td.ExecTimeout != phase {
			return
		}
		td.Cancel()
		time.Sleep(time.Minute)
	}

	// Possibly time out before the command actually runs.
	simulate(StartTimeout)
	req.Start(true)

	status := 0
	if req.Cmd == shellCmd("", []string{"long_sleep"}) {
		time.Sleep(time.Hour)
	} else {
		status = req.RunRealCmd()
	}

	// Possibly time out after the command ran but before its status is sent.
	simulate(EndTimeout)
	req.End(status)
}
// shellCmd builds a shell command string that runs args via the shell's exec
// builtin, with each argument escaped by shutil.
// If dir is non-empty the command first changes into dir (escaped as well).
// It's copied from ssh/platform.go. TODO(oka): consider refactoring if duplication becomes bigger.
func shellCmd(dir string, args []string) string {
	execPart := "exec " + shutil.EscapeSlice(args)
	if dir == "" {
		return execPart
	}
	// Return 125 (chosen arbitrarily) if dir does not exist.
	// TODO(nya): Consider handling the directory error more gracefully.
	return fmt.Sprintf("cd %s > /dev/null 2>&1 || exit 125; %s", shutil.Escape(dir), execPart)
}