blob: 029216fd5edea07047463188d7db892d544fa71e [file] [log] [blame] [edit]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshtest
import (
"crypto/rsa"
"errors"
"os"
)
const (
defaultKeyBits = 1024
)
// MustGenerateKeys can be called from a test file's init function to generate
// 1024-bit user and host keys. Panics on error.
func MustGenerateKeys() (userKey, hostKey *rsa.PrivateKey) {
var err error
if userKey, hostKey, err = GenerateKeys(defaultKeyBits); err != nil {
panic(err)
}
return userKey, hostKey
}
// TestData contains common data that can be used by tests that interact with an SSHServer.
type TestData struct { // NOLINT
Srvs []*SSHServer
UserKeyFile string
}
// NewTestData initializes and returns a TestData struct. Panics on error.
func NewTestData(handlers ...ExecHandler) *TestData {
if len(handlers) == 0 {
panic(errors.New("no handler is specfied"))
}
userKey, hostKey := MustGenerateKeys()
var servers []*SSHServer
for _, handler := range handlers {
srv, err := NewSSHServer(&userKey.PublicKey, hostKey, handler)
if err != nil {
for _, srv := range servers {
srv.Close()
}
panic(err)
}
servers = append(servers, srv)
}
userKeyFile, err := WriteKey(userKey)
if err != nil {
for _, srv := range servers {
srv.Close()
}
panic(err)
}
return &TestData{
Srvs: servers,
UserKeyFile: userKeyFile,
}
}
// Close stops the SSHServer and deletes the user key file.
func (d *TestData) Close() {
for _, s := range d.Srvs {
s.Close()
}
os.Remove(d.UserKeyFile)
}