blob: 558fc515846f080bca13949a28249c759b210e3a [file] [edit]
package ssh_config
import (
"bytes"
"log"
"os"
"path/filepath"
"strings"
"testing"
)
// loadFile returns the full contents of filename. Any read error fails the
// calling test immediately; t.Helper attributes the failure to the caller.
func loadFile(t *testing.T, filename string) []byte {
	t.Helper()
	contents, readErr := os.ReadFile(filename)
	if readErr != nil {
		t.Fatal(readErr)
	}
	return contents
}
// files lists the testdata configs that TestDecode round-trips through
// Decode and Config.String.
var files = []string{
	"testdata/config1",
	"testdata/config2",
	"testdata/eol-comments",
}

// TestDecode verifies that every file in files survives a Decode -> String
// round trip byte-for-byte. Each file runs as its own subtest, so a decode
// failure (t.Fatal) in one file no longer prevents the remaining files from
// being checked.
func TestDecode(t *testing.T) {
	for _, filename := range files {
		// Name the subtest by base name: a "/" in a subtest name would create
		// nested subtest levels and complicate -run matching. The closure runs
		// synchronously inside t.Run, so capturing filename is safe.
		t.Run(filepath.Base(filename), func(t *testing.T) {
			data := loadFile(t, filename)
			cfg, err := Decode(bytes.NewReader(data))
			if err != nil {
				t.Fatal(err)
			}
			if out := cfg.String(); out != string(data) {
				t.Errorf("%s out != data: got:\n%s\nwant:\n%s\n", filename, out, string(data))
			}
		})
	}
}
// testConfigFinder builds a config-finder callback that always reports
// filename as the config path. Tests use it to point UserSettings at a
// testdata fixture.
func testConfigFinder(filename string) func() string {
	finder := func() string {
		return filename
	}
	return finder
}
// nullConfigFinder is a config-finder callback that always reports an empty
// path, i.e. no config file.
func nullConfigFinder() string {
	var none string
	return none
}
// TestGet checks that Get resolves User for host "wap" from
// testdata/config1.
func TestGet(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config1")}
	if user := settings.Get("wap", "User"); user != "root" {
		t.Errorf("expected to find User root, got %q", user)
	}
}
// TestGetWithDefault expects GetStrict to report PasswordAuthentication
// "yes" for host "wap" with no error. The test name suggests the value comes
// from the package default rather than testdata/config1 (not verified here).
func TestGetWithDefault(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config1")}
	got, err := settings.GetStrict("wap", "PasswordAuthentication")
	switch {
	case err != nil:
		t.Fatalf("expected nil err, got %v", err)
	case got != "yes":
		t.Errorf("expected to get PasswordAuthentication yes, got %q", got)
	}
}
// TestGetAllWithDefault is the GetAllStrict counterpart of
// TestGetWithDefault: it expects exactly one value, "yes".
func TestGetAllWithDefault(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config1")}
	vals, err := settings.GetAllStrict("wap", "PasswordAuthentication")
	if err != nil {
		t.Fatalf("expected nil err, got %v", err)
	}
	if ok := len(vals) == 1 && vals[0] == "yes"; !ok {
		t.Errorf("expected to get PasswordAuthentication yes, got %q", vals)
	}
}
// TestGetIdentities checks how GetAllStrict resolves IdentityFile against
// testdata/identities. It covers four cases:
//   - a host with a single IdentityFile;
//   - a host with two IdentityFiles, returned in order;
//   - "randomhost", compared against defaultIdentityFiles;
//   - "protocol1", which is expected to yield no values.
func TestGetIdentities(t *testing.T) {
	us := &UserSettings{
		userConfigFinder: testConfigFinder("testdata/identities"),
	}

	val, err := us.GetAllStrict("hasidentity", "IdentityFile")
	if err != nil {
		t.Errorf("expected nil err, got %v", err)
	}
	if len(val) != 1 || val[0] != "file1" {
		t.Errorf(`expected ["file1"], got %v`, val)
	}

	val, err = us.GetAllStrict("has2identity", "IdentityFile")
	if err != nil {
		t.Errorf("expected nil err, got %v", err)
	}
	if len(val) != 2 || val[0] != "f1" || val[1] != "f2" {
		// Raw string literals do not process escapes. The quotes must not be
		// backslash-escaped, or the message prints `[\"f1\", \"f2\"]`.
		t.Errorf(`expected ["f1", "f2"], got %v`, val)
	}

	val, err = us.GetAllStrict("randomhost", "IdentityFile")
	if err != nil {
		t.Errorf("expected nil err, got %v", err)
	}
	if len(val) != len(defaultIdentityFiles) {
		// TODO: return the right values here. Until then a length mismatch is
		// only logged, not reported as a test failure.
		log.Printf("expected defaults, got %v", val)
	} else {
		for i, v := range defaultIdentityFiles {
			if val[i] != v {
				t.Errorf("invalid %d in val, expected %s got %s", i, v, val[i])
			}
		}
	}

	// "protocol1" host sets Protocol 1, but Protocol is ignored in modern
	// OpenSSH (only SSH2 exists). No IdentityFile is set for this host, so
	// the result is empty (IdentityFile has no single default value).
	val, err = us.GetAllStrict("protocol1", "IdentityFile")
	if err != nil {
		t.Errorf("expected nil err, got %v", err)
	}
	if len(val) != 0 {
		t.Errorf("expected [], got %v", val)
	}
}
// TestGetQuotedValues checks testdata/quoted-identities in two ways:
//   - double-quoted values are returned without their quotes, and unquoted
//     values are unaffected;
//   - decoding the file and marshaling it back reproduces the original bytes,
//     so the quotes are preserved in the output.
func TestGetQuotedValues(t *testing.T) {
	const fixture = "testdata/quoted-identities"
	us := &UserSettings{
		userConfigFinder: testConfigFinder(fixture),
	}

	lookups := []struct {
		label, host, key, want string
	}{
		{"IdentityFile with quotes", "hasquotedidentity", "IdentityFile", "/Users/testuser/.ssh/quoted_key"},
		{"HostName with quotes", "hasquotedhostname", "HostName", "example.com"},
		{"IdentityFile without quotes", "hasunquotedidentity", "IdentityFile", "/Users/testuser/.ssh/unquoted_key"},
	}
	for _, lk := range lookups {
		val, err := us.GetStrict(lk.host, lk.key)
		if err != nil {
			t.Fatal(err)
		}
		if val != lk.want {
			t.Errorf("%s: got %q, want %q", lk.label, val, lk.want)
		}
	}

	// Verify roundtripping preserves quotes in the output. Read the fixture
	// once and use the same bytes for decoding and for the comparison. The
	// file used to be opened for Decode and then read a second time.
	original := loadFile(t, fixture)
	cfg, err := Decode(bytes.NewReader(original))
	if err != nil {
		t.Fatal(err)
	}
	out, err := cfg.MarshalText()
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(out, original) {
		t.Errorf("roundtrip mismatch:\ngot:\n%s\nwant:\n%s", out, original)
	}
}
// TestGetInvalidPort expects GetStrict to reject a non-numeric Port in
// testdata/invalid-port. It requires an empty value and the exact
// strconv-derived error message.
func TestGetInvalidPort(t *testing.T) {
	const wantErr = `ssh_config: strconv.ParseUint: parsing "notanumber": invalid syntax`
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/invalid-port")}
	port, err := settings.GetStrict("test.test", "Port")
	if err == nil {
		t.Fatalf("expected non-nil err, got nil")
	}
	if port != "" {
		t.Errorf("expected to get '' for val, got %q", port)
	}
	if msg := err.Error(); msg != wantErr {
		t.Errorf("wrong error: got %v", err)
	}
}
// TestGetNotFoundNoDefault expects GetStrict to return "" with no error for
// CanonicalDomains, a key the test treats as unset in testdata/config1 and
// without a default.
func TestGetNotFoundNoDefault(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config1")}
	got, err := settings.GetStrict("wap", "CanonicalDomains")
	switch {
	case err != nil:
		t.Fatalf("expected nil err, got %v", err)
	case got != "":
		t.Errorf("expected to get CanonicalDomains '', got %q", got)
	}
}
// TestGetAllNotFoundNoDefault is the GetAllStrict counterpart of
// TestGetNotFoundNoDefault. It expects no values and no error for
// CanonicalDomains.
func TestGetAllNotFoundNoDefault(t *testing.T) {
	us := &UserSettings{
		userConfigFinder: testConfigFinder("testdata/config1"),
	}
	val, err := us.GetAllStrict("wap", "CanonicalDomains")
	if err != nil {
		t.Fatalf("expected nil err, got %v", err)
	}
	if len(val) != 0 {
		// val is a slice; describe the expectation as "no values" rather than
		// the string-style '' used by the single-value test.
		t.Errorf("expected no CanonicalDomains values, got %q", val)
	}
}
// TestGetWildcard checks which Port testdata/config3 yields for a set of host
// names and IP addresses. The cases exercise that fixture's wildcard Host
// patterns.
//
// The cases are table-driven so each failure message is built from the
// expected value. Previously two messages claimed "Port 24" while the check
// was for "25".
func TestGetWildcard(t *testing.T) {
	us := &UserSettings{
		userConfigFinder: testConfigFinder("testdata/config3"),
	}
	cases := []struct {
		host, want string
	}{
		{"bastion.stage.i.us.example.net", "22"},
		{"bastion.net", "25"},
		{"10.2.3.4", "23"},
		{"101.2.3.4", "25"},
		{"20.20.20.4", "24"},
		{"20.20.20.20", "25"},
	}
	for _, tc := range cases {
		if got := us.Get(tc.host, "Port"); got != tc.want {
			t.Errorf("%s: expected to find Port %s, got %q", tc.host, tc.want, got)
		}
	}
}
// TestGetExtraSpaces checks that Port 1234 is found for host "test.test" in
// testdata/extraspace. Per the test name, that fixture has extra whitespace.
func TestGetExtraSpaces(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/extraspace")}
	if port := settings.Get("test.test", "Port"); port != "1234" {
		t.Errorf("expected to find Port 1234, got %q", port)
	}
}
// TestGetCaseInsensitive checks that key lookup ignores case: "uSER" must
// resolve the same value as "User".
func TestGetCaseInsensitive(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config1")}
	if user := settings.Get("wap", "uSER"); user != "root" {
		t.Errorf("expected to find User root, got %q", user)
	}
}
// TestGetEmpty checks GetStrict when both the user and system finders report
// no config file. It expects an empty value and no error.
func TestGetEmpty(t *testing.T) {
	settings := &UserSettings{
		userConfigFinder:   nullConfigFinder,
		systemConfigFinder: nullConfigFinder,
	}
	user, err := settings.GetStrict("wap", "User")
	if err != nil {
		t.Errorf("expected nil error, got %v", err)
	}
	if user != "" {
		t.Errorf("expected to get empty string, got %q", user)
	}
}
// TestGetEqsign checks that Port and Port2 for "test.test" resolve from
// testdata/eqsign. Per the test name, that fixture uses "key=value" syntax.
func TestGetEqsign(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/eqsign")}
	expectations := []struct{ key, want string }{
		{"Port", "1234"},
		{"Port2", "5678"},
	}
	for _, e := range expectations {
		if got := settings.Get("test.test", e.key); got != e.want {
			t.Errorf("expected to find %s %s, got %q", e.key, e.want, got)
		}
	}
}
// includeFile is the SSH config fixture that TestInclude and
// TestIncludeSystem write to disk before resolving an Include. It sets
// Port 4567 for a host name that should never exist in a real config.
var includeFile = []byte(`
# This host should not exist, so we can use it for test purposes / it won't
# interfere with any other configurations.
Host kevinburke.ssh_config.test.example.com
Port 4567
`)
// TestInclude writes includeFile into ~/.ssh and expects testdata/include to
// pick up its Port. Presumably testdata/include has an Include directive that
// names this file; confirm against the fixture. The test is skipped in -short
// mode, or when the file cannot be written.
func TestInclude(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping fs write in short mode")
	}
	target := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file")
	if err := os.WriteFile(target, includeFile, 0644); err != nil {
		t.Skipf("couldn't write SSH config file: %v", err.Error())
	}
	defer os.Remove(target)

	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/include")}
	if port := settings.Get("kevinburke.ssh_config.test.example.com", "Port"); port != "4567" {
		t.Errorf("expected to find Port=4567 in included file, got %q", port)
	}
}
// TestIncludeSystem is the system-config variant of TestInclude. It writes
// includeFile under /etc/ssh and loads testdata/include through
// systemConfigFinder. Writing to /etc/ssh usually needs elevated privileges;
// when the write fails the test is skipped.
func TestIncludeSystem(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping fs write in short mode")
	}
	target := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file")
	if err := os.WriteFile(target, includeFile, 0644); err != nil {
		t.Skipf("couldn't write SSH config file: %v", err.Error())
	}
	defer os.Remove(target)

	settings := &UserSettings{systemConfigFinder: testConfigFinder("testdata/include")}
	if port := settings.Get("kevinburke.ssh_config.test.example.com", "Port"); port != "4567" {
		t.Errorf("expected to find Port=4567 in included file, got %q", port)
	}
}
// recursiveIncludeFile is a config fragment whose Include names
// "kevinburke-ssh-config-recursive-include". TestIncludeRecursive writes the
// fragment to a file with that same name, so the file includes itself and
// resolution should end in ErrDepthExceeded.
var recursiveIncludeFile = []byte(`
Host kevinburke.ssh_config.test.example.com
Include kevinburke-ssh-config-recursive-include
`)
// TestIncludeRecursive writes the self-including recursiveIncludeFile into
// ~/.ssh. It expects GetStrict to give up with ErrDepthExceeded and an empty
// value rather than recursing forever.
func TestIncludeRecursive(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping fs write in short mode")
	}
	target := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include")
	if err := os.WriteFile(target, recursiveIncludeFile, 0644); err != nil {
		t.Skipf("couldn't write SSH config file: %v", err.Error())
	}
	defer os.Remove(target)

	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/include-recursive")}
	port, err := settings.GetStrict("kevinburke.ssh_config.test.example.com", "Port")
	if err != ErrDepthExceeded {
		t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err)
	}
	if port != "" {
		t.Errorf("non-empty string value %s", port)
	}
}
func TestIncludeString(t *testing.T) {
if testing.Short() {
t.Skip("skipping fs write in short mode")
}
data, err := os.ReadFile("testdata/include")
if err != nil {
log.Fatal(err)
}
c, err := Decode(bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
s := c.String()
if s != string(data) {
t.Errorf("mismatch: got %q\nwant %q", s, string(data))
}
}
// shellIncludeFile is the fixture that TestIncludeShellHomeDirectory writes
// directly into the home directory. Its contents are identical to
// includeFile: Port 4567 for a host name that should never exist.
var shellIncludeFile = []byte(`
# This host should not exist, so we can use it for test purposes / it won't
# interfere with any other configurations.
Host kevinburke.ssh_config.test.example.com
Port 4567
`)
// TestIncludeShellHomeDirectory writes shellIncludeFile directly into the
// home directory, not ~/.ssh. It expects testdata/include-shell to pick up
// its Port. Per the test name, that fixture presumably refers to the file
// with a shell-style home-directory path; confirm against the fixture.
func TestIncludeShellHomeDirectory(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping fs write in short mode")
	}
	target := filepath.Join(homedir(), "kevinburke-ssh-config-shell-include")
	if err := os.WriteFile(target, shellIncludeFile, 0644); err != nil {
		t.Skipf("couldn't write SSH config file: %v", err.Error())
	}
	defer os.Remove(target)

	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/include-shell")}
	if port := settings.Get("kevinburke.ssh_config.test.example.com", "Port"); port != "4567" {
		t.Errorf("expected to find Port=4567 in included file, got %q", port)
	}
}
// matchTests is the table for TestMatches. Each case has three fields:
//   - in: the pattern list for a Host;
//   - alias: the host name to test;
//   - want: whether Host.Matches(alias) should report a match.
//
// Patterns starting with "!" are negations.
var matchTests = []struct {
	in    []string
	alias string
	want  bool
}{
	{in: []string{"*"}, alias: "any.test", want: true},
	{in: []string{"a", "b", "*", "c"}, alias: "any.test", want: true},
	{in: []string{"a", "b", "c"}, alias: "any.test", want: false},
	{in: []string{"any.test"}, alias: "any1test", want: false},
	{in: []string{"192.168.0.?"}, alias: "192.168.0.1", want: true},
	{in: []string{"192.168.0.?"}, alias: "192.168.0.10", want: false},
	{in: []string{"*.co.uk"}, alias: "bbc.co.uk", want: true},
	{in: []string{"*.co.uk"}, alias: "subdomain.bbc.co.uk", want: true},
	{in: []string{"*.*.co.uk"}, alias: "bbc.co.uk", want: false},
	{in: []string{"*.*.co.uk"}, alias: "subdomain.bbc.co.uk", want: true},
	{in: []string{"*.example.com", "!*.dialup.example.com", "foo.dialup.example.com"}, alias: "foo.dialup.example.com", want: false},
	{in: []string{"test.*", "!test.host"}, alias: "test.host", want: false},
}
// TestMatches compiles each pattern list in matchTests into a Host and checks
// Host.Matches against the case's alias.
func TestMatches(t *testing.T) {
	for _, tc := range matchTests {
		pats := make([]*Pattern, 0, len(tc.in))
		for _, raw := range tc.in {
			p, err := NewPattern(raw)
			if err != nil {
				t.Fatalf("error compiling pattern %s: %v", raw, err)
			}
			pats = append(pats, p)
		}
		h := &Host{Patterns: pats}
		if got := h.Matches(tc.alias); got != tc.want {
			t.Errorf("host(%q).Matches(%q): got %v, want %v", tc.in, tc.alias, got, tc.want)
		}
	}
}
// TestMatchExecUnsupported checks that Decode rejects a "Match Exec" block.
// The error must contain "ssh_config: Match Exec is not supported".
func TestMatchExecUnsupported(t *testing.T) {
	const config = "Match Exec \"echo hello\"\nPort 2222"
	_, err := Decode(strings.NewReader(config))
	switch {
	case err == nil:
		t.Fatal("expected Match Exec to error, didn't")
	case !strings.Contains(err.Error(), "ssh_config: Match Exec is not supported"):
		t.Errorf("wrong error: %v", err)
	}
}
// TestIndexInRange checks that GetStrict resolves User "root" for host "wap"
// from testdata/config4 without error. The name suggests this is a regression
// test for an index-out-of-range issue in that fixture (not verified here).
func TestIndexInRange(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/config4")}
	const want = "root"
	got, err := settings.GetStrict("wap", "User")
	if err != nil {
		t.Fatal(err)
	}
	if got != want {
		t.Errorf("expected User to be %q, got %q", want, got)
	}
}
// TestDosLinesEndingsDecode checks that values from testdata/dos-lines
// (presumably CRLF-terminated, per the test name) come back without
// line-ending residue.
func TestDosLinesEndingsDecode(t *testing.T) {
	settings := &UserSettings{userConfigFinder: testConfigFinder("testdata/dos-lines")}
	checks := []struct{ host, key, want string }{
		{"wap", "User", "root"},
		{"wap2", "HostName", "8.8.8.8"},
	}
	for _, c := range checks {
		got, err := settings.GetStrict(c.host, c.key)
		if err != nil {
			t.Fatal(err)
		}
		if got != c.want {
			t.Errorf("expected %s to be %q, got %q", c.key, c.want, got)
		}
	}
}
// TestNoTrailingNewline checks that the last line of a config file is still
// parsed when the file has no final newline. The system config is disabled so
// only the fixture is consulted.
func TestNoTrailingNewline(t *testing.T) {
	settings := &UserSettings{
		userConfigFinder:   testConfigFinder("testdata/config-no-ending-newline"),
		systemConfigFinder: nullConfigFinder,
	}
	got, err := settings.GetStrict("example", "Port")
	if err != nil {
		t.Fatal(err)
	}
	if got != "4242" {
		t.Errorf("wrong port: got %q want 4242", got)
	}
}
// TestEOLCommentSpacing is a regression test for issue #50. When a Host or KV
// is built in code and given an end-of-line comment, String must put a space
// before the '#' ("Host foo #comment", not "Host foo#comment").
func TestEOLCommentSpacing(t *testing.T) {
	pat, err := NewPattern("example")
	if err != nil {
		t.Fatal(err)
	}
	h := &Host{
		Patterns: []*Pattern{pat},
		Nodes:    []Node{&KV{Key: " Hostname", Value: "1.2.3.4"}},
	}
	h.EOLComment = "my comment"
	if rendered := h.String(); !strings.Contains(rendered, "Host example #my comment") {
		t.Errorf("expected space before comment, got %q", rendered)
	}

	// The same spacing rule applies to a KV built with a Comment.
	kv := &KV{Key: " Port", Value: "22", Comment: "ssh port"}
	if rendered := kv.String(); !strings.Contains(rendered, "22 #ssh port") {
		t.Errorf("expected space before KV comment, got %q", rendered)
	}
}
// TestCustomFinder checks that a finder installed through the public
// ConfigFinder method is used for lookups on a zero-value UserSettings.
func TestCustomFinder(t *testing.T) {
	settings := &UserSettings{}
	settings.ConfigFinder(testConfigFinder("testdata/config1"))
	if user := settings.Get("wap", "User"); user != "root" {
		t.Errorf("expected to find User root, got %q", user)
	}
}