blob: b9de8dff1c064449e96554731d0890a2fa4cb293 [file] [log] [blame]
// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package ssh
import (
_ "embed"
"errors"
"fmt"
"log"
"os"
"os/user"
"path/filepath"
"golang.org/x/crypto/ssh"
)
// testingRSA is the SSH RSA private key embedded as bytes from the
// testing_rsa file in this package's directory. NewKeyChain always adds it
// to the key chain it creates.
// https://chromium.googlesource.com/chromiumos/chromite/+/main/ssh_keys/testing_rsa
//
//go:embed testing_rsa
var testingRSA []byte
// File names of the private keys handled by KeyChain.
const (
// testingRSAFileName is the name under which the embedded testing key is
// written into the KeyChain's temporary directory.
testingRSAFileName = "testing_rsa"
// partnerTestingRSAFileName is the name of the partner key file. It is used
// both to look the key up under ~/.ssh and to name the copy written into the
// KeyChain's temporary directory.
partnerTestingRSAFileName = "partner_testing_rsa"
)
// KeyChain holds SSH private keys in two forms: as parsed signers for
// in-process authentication (see SSHAuthMethod) and as files in a temporary
// directory for the ssh command line (see SSHCommandOptions).
//
// Create one with NewKeyChain and release its resources with Delete.
type KeyChain struct {
// Signers parsed from the private keys added to the key chain
parsedKeys []ssh.Signer
// Temporary directory holding the private keys
dir string
// Paths of the private key files written into dir
files []string
}
// getPartnerTestingRSA loads the partner testing key from the .ssh directory
// inside the current user's home directory.
//
// It returns the raw file contents, or an error when the current user cannot
// be looked up, the user has no home directory, or the file cannot be read.
func getPartnerTestingRSA() ([]byte, error) {
	current, err := user.Current()
	switch {
	case err != nil:
		return nil, err
	case current.HomeDir == "":
		return nil, errors.New("cannot determine home directory")
	}

	keyPath := filepath.Join(current.HomeDir, ".ssh", partnerTestingRSAFileName)
	return os.ReadFile(keyPath)
}
// NewKeyChain creates a KeyChain holding bundled SSH keys.
//
// The embedded testing key is always added. The partner testing key from
// ~/.ssh is added as well when it can be read; when it cannot be read, a
// message is logged and the KeyChain is returned without it. A key that was
// read but cannot be added makes NewKeyChain fail.
//
// KeyChain.Delete() must be called to clean up resources. When NewKeyChain
// returns an error, removal of the temporary directory has already been
// attempted and there is nothing left for the caller to delete.
func NewKeyChain() (*KeyChain, error) {
	tempDir, err := os.MkdirTemp("", "fflash-keychain-*")
	if err != nil {
		return nil, err
	}
	kc := &KeyChain{dir: tempDir}

	// abort tears down the partially built key chain. A cleanup failure is
	// logged instead of returned so it cannot mask the error that caused the
	// abort, but it is no longer silently dropped either.
	abort := func(cause error) (*KeyChain, error) {
		if derr := kc.Delete(); derr != nil {
			log.Printf("Cannot remove key chain directory %s: %s", kc.dir, derr)
		}
		return nil, cause
	}

	if err := kc.addKeyBytes(testingRSAFileName, testingRSA); err != nil {
		return abort(err)
	}

	// The partner key is optional: not being able to read it is not an error.
	partnerTestingRSA, err := getPartnerTestingRSA()
	if err != nil {
		log.Printf("Not using partner key: cannot read ~/.ssh/%s: %s", partnerTestingRSAFileName, err)
		return kc, nil
	}

	if err := kc.addKeyBytes(partnerTestingRSAFileName, partnerTestingRSA); err != nil {
		return abort(err)
	}
	return kc, nil
}
// addKeyBytes parses the private key in b and adds it to the key chain: the
// raw bytes are written to a file called name inside kc.dir, and the parsed
// signer is kept for SSHAuthMethod.
//
// The key chain is only modified after both parsing and writing succeeded,
// so parsedKeys and files stay in step when an error is returned. A parse
// failure is wrapped with %w, which keeps the underlying error reachable
// through errors.Is and errors.As.
func (kc *KeyChain) addKeyBytes(name string, b []byte) error {
	key, err := ssh.ParsePrivateKey(b)
	if err != nil {
		return fmt.Errorf("Cannot parse private key %q: %w", name, err)
	}

	path := filepath.Join(kc.dir, name)
	// 0400: the key file is readable by its owner only.
	if err := os.WriteFile(path, b, 0400); err != nil {
		return err
	}

	kc.parsedKeys = append(kc.parsedKeys, key)
	kc.files = append(kc.files, path)
	return nil
}
// SSHAuthMethod returns an ssh.AuthMethod that offers every key parsed into
// the KeyChain for public key authentication.
func (kc *KeyChain) SSHAuthMethod() ssh.AuthMethod {
	signers := kc.parsedKeys
	return ssh.PublicKeys(signers...)
}
// SSHCommandOptions returns ssh command line options that use the keys: one
// -oIdentityFile option per key file in the KeyChain, in the order the keys
// were added. The result is nil when the KeyChain has no key files.
func (kc *KeyChain) SSHCommandOptions() []string {
	var args []string
	for i := range kc.files {
		identity := fmt.Sprintf("-oIdentityFile=%s", kc.files[i])
		args = append(args, identity)
	}
	return args
}
// Delete cleans up the key chain by removing its temporary directory together
// with everything inside it. It returns the error from the removal, if any.
func (kc *KeyChain) Delete() error {
	if err := os.RemoveAll(kc.dir); err != nil {
		return err
	}
	return nil
}