| package socket |
| |
| import ( |
| "errors" |
| "io" |
| "io/fs" |
| "net" |
| "os" |
| "runtime" |
| "strings" |
| "sync/atomic" |
| "testing" |
| "time" |
| |
| "gotest.tools/v3/assert" |
| "gotest.tools/v3/poll" |
| ) |
| |
| func TestPluginServer(t *testing.T) { |
| t.Run("connection closes with EOF when server closes", func(t *testing.T) { |
| called := make(chan struct{}) |
| srv, err := NewPluginServer(func(_ net.Conn) { close(called) }) |
| assert.NilError(t, err) |
| assert.Assert(t, srv != nil, "returned nil server but no error") |
| |
| addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) |
| assert.NilError(t, err, "failed to resolve server address") |
| |
| conn, err := net.DialUnix("unix", nil, addr) |
| assert.NilError(t, err, "failed to dial returned server") |
| defer conn.Close() |
| |
| done := make(chan error, 1) |
| go func() { |
| _, err := conn.Read(make([]byte, 1)) |
| done <- err |
| }() |
| |
| select { |
| case <-called: |
| case <-time.After(10 * time.Millisecond): |
| t.Fatal("handler not called") |
| } |
| |
| srv.Close() |
| |
| select { |
| case err := <-done: |
| if !errors.Is(err, io.EOF) { |
| t.Fatalf("exepcted EOF error, got: %v", err) |
| } |
| case <-time.After(10 * time.Millisecond): |
| } |
| }) |
| |
| t.Run("allows reconnects", func(t *testing.T) { |
| var calls int32 |
| h := func(_ net.Conn) { |
| atomic.AddInt32(&calls, 1) |
| } |
| |
| srv, err := NewPluginServer(h) |
| assert.NilError(t, err) |
| defer srv.Close() |
| |
| assert.Check(t, srv.Addr() != nil, "returned nil addr but no error") |
| |
| addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) |
| assert.NilError(t, err, "failed to resolve server address") |
| |
| waitForCalls := func(n int) { |
| poll.WaitOn(t, func(t poll.LogT) poll.Result { |
| if atomic.LoadInt32(&calls) == int32(n) { |
| return poll.Success() |
| } |
| return poll.Continue("waiting for handler to be called") |
| }) |
| } |
| |
| otherConn, err := net.DialUnix("unix", nil, addr) |
| assert.NilError(t, err, "failed to dial returned server") |
| otherConn.Close() |
| waitForCalls(1) |
| |
| conn, err := net.DialUnix("unix", nil, addr) |
| assert.NilError(t, err, "failed to redial server") |
| defer conn.Close() |
| waitForCalls(2) |
| |
| // and again but don't close the existing connection |
| conn2, err := net.DialUnix("unix", nil, addr) |
| assert.NilError(t, err, "failed to redial server") |
| defer conn2.Close() |
| waitForCalls(3) |
| |
| srv.Close() |
| |
| // now make sure we get EOF on the existing connections |
| buf := make([]byte, 1) |
| _, err = conn.Read(buf) |
| assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) |
| |
| _, err = conn2.Read(buf) |
| assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) |
| }) |
| |
| t.Run("does not leak sockets to local directory", func(t *testing.T) { |
| srv, err := NewPluginServer(nil) |
| assert.NilError(t, err) |
| assert.Check(t, srv != nil, "returned nil server but no error") |
| checkDirNoNewPluginServer(t) |
| |
| addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) |
| assert.NilError(t, err, "failed to resolve server address") |
| |
| _, err = net.DialUnix("unix", nil, addr) |
| assert.NilError(t, err, "failed to dial returned server") |
| checkDirNoNewPluginServer(t) |
| }) |
| |
| t.Run("does not panic on Close if server is nil", func(t *testing.T) { |
| var srv *PluginServer |
| defer func() { |
| if r := recover(); r != nil { |
| t.Errorf("panicked on Close") |
| } |
| }() |
| |
| err := srv.Close() |
| assert.NilError(t, err) |
| }) |
| } |
| |
| func checkDirNoNewPluginServer(t *testing.T) { |
| t.Helper() |
| |
| files, err := os.ReadDir(".") |
| assert.NilError(t, err, "failed to list files in dir to check for leaked sockets") |
| |
| for _, f := range files { |
| info, err := f.Info() |
| assert.NilError(t, err, "failed to check file info") |
| // check for a socket with `docker_cli_` in the name (from `SetupConn()`) |
| if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket { |
| t.Fatal("found socket in a local directory") |
| } |
| } |
| } |
| |
| func TestConnectAndWait(t *testing.T) { |
| t.Run("calls cancel func on EOF", func(t *testing.T) { |
| srv, err := NewPluginServer(nil) |
| assert.NilError(t, err, "failed to setup server") |
| defer srv.Close() |
| |
| done := make(chan struct{}) |
| t.Setenv(EnvKey, srv.Addr().String()) |
| cancelFunc := func() { |
| done <- struct{}{} |
| } |
| ConnectAndWait(cancelFunc) |
| |
| select { |
| case <-done: |
| t.Fatal("unexpectedly done") |
| default: |
| } |
| |
| srv.Close() |
| |
| select { |
| case <-done: |
| case <-time.After(10 * time.Millisecond): |
| t.Fatal("cancel function not closed after 10ms") |
| } |
| }) |
| |
| // TODO: this test cannot be executed with `t.Parallel()`, due to |
| // relying on goroutine numbers to ensure correct behaviour |
| t.Run("connect goroutine exits after EOF", func(t *testing.T) { |
| srv, err := NewPluginServer(nil) |
| assert.NilError(t, err, "failed to setup server") |
| |
| defer srv.Close() |
| |
| t.Setenv(EnvKey, srv.Addr().String()) |
| numGoroutines := runtime.NumGoroutine() |
| |
| ConnectAndWait(func() {}) |
| assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) |
| |
| srv.Close() |
| |
| poll.WaitOn(t, func(t poll.LogT) poll.Result { |
| if runtime.NumGoroutine() > numGoroutines+1 { |
| return poll.Continue("waiting for connect goroutine to exit") |
| } |
| return poll.Success() |
| }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) |
| }) |
| } |