Merge pull request #100951 from saschagrunert/automated-cherry-pick-of-#99839-upstream-release-1.21
Automated cherry pick of #99839: Cleanup portforward streams after their usage
Kubernetes-commit: 9745a35d15c40607d94424776cc84a130c64d75f
diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go
index 00ce5f7..32f0757 100644
--- a/pkg/util/httpstream/httpstream.go
+++ b/pkg/util/httpstream/httpstream.go
@@ -78,6 +78,8 @@
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
SetIdleTimeout(timeout time.Duration)
+ // RemoveStreams removes the given streams from the Connection.
+ RemoveStreams(streams ...Stream)
}
// Stream represents a bidirectional communications channel that is part of an
diff --git a/pkg/util/httpstream/spdy/connection.go b/pkg/util/httpstream/spdy/connection.go
index 21b2568..3da7457 100644
--- a/pkg/util/httpstream/spdy/connection.go
+++ b/pkg/util/httpstream/spdy/connection.go
@@ -31,7 +31,7 @@
// streams.
type connection struct {
conn *spdystream.Connection
- streams []httpstream.Stream
+ streams map[uint32]httpstream.Stream
streamLock sync.Mutex
newStreamHandler httpstream.NewStreamHandler
ping func() (time.Duration, error)
@@ -85,7 +85,12 @@
// will be invoked when the server receives a newly created stream from the
// client.
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration, pingFn func() (time.Duration, error)) httpstream.Connection {
- c := &connection{conn: conn, newStreamHandler: newStreamHandler, ping: pingFn}
+ c := &connection{
+ conn: conn,
+ newStreamHandler: newStreamHandler,
+ ping: pingFn,
+ streams: make(map[uint32]httpstream.Stream),
+ }
go conn.Serve(c.newSpdyStream)
if pingPeriod > 0 && pingFn != nil {
go c.sendPings(pingPeriod)
@@ -105,7 +110,7 @@
// calling Reset instead of Close ensures that all streams are fully torn down
s.Reset()
}
- c.streams = make([]httpstream.Stream, 0)
+ c.streams = make(map[uint32]httpstream.Stream, 0)
c.streamLock.Unlock()
// now that all streams are fully torn down, it's safe to call close on the underlying connection,
@@ -114,6 +119,15 @@
return c.conn.Close()
}
+// RemoveStreams removes the given streams from the connection's stream map (keyed by Identifier()) under streamLock; identifiers that are not registered are ignored.
+func (c *connection) RemoveStreams(streams ...httpstream.Stream) {
+ c.streamLock.Lock()
+ for _, stream := range streams {
+ delete(c.streams, stream.Identifier())
+ }
+ c.streamLock.Unlock()
+}
+
// CreateStream creates a new stream with the specified headers and registers
// it with the connection.
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
@@ -133,7 +147,7 @@
// it owns.
func (c *connection) registerStream(s httpstream.Stream) {
c.streamLock.Lock()
- c.streams = append(c.streams, s)
+ c.streams[s.Identifier()] = s
c.streamLock.Unlock()
}
diff --git a/pkg/util/httpstream/spdy/connection_test.go b/pkg/util/httpstream/spdy/connection_test.go
index 9e551dd..edef917 100644
--- a/pkg/util/httpstream/spdy/connection_test.go
+++ b/pkg/util/httpstream/spdy/connection_test.go
@@ -290,3 +290,41 @@
t.Errorf("timed out waiting for server to exit")
}
}
+
+type fakeStream struct{ id uint32 }
+
+func (*fakeStream) Read(p []byte) (int, error) { return 0, nil }
+func (*fakeStream) Write(p []byte) (int, error) { return 0, nil }
+func (*fakeStream) Close() error { return nil }
+func (*fakeStream) Reset() error { return nil }
+func (*fakeStream) Headers() http.Header { return nil }
+func (f *fakeStream) Identifier() uint32 { return f.id }
+
+func TestConnectionRemoveStreams(t *testing.T) {
+ c := &connection{streams: make(map[uint32]httpstream.Stream)}
+ stream0 := &fakeStream{id: 0}
+ stream1 := &fakeStream{id: 1}
+ stream2 := &fakeStream{id: 2}
+
+ c.registerStream(stream0)
+ c.registerStream(stream1)
+
+ if len(c.streams) != 2 {
+ t.Fatalf("should have two streams, has %d", len(c.streams))
+ }
+
+ // removing a stream that was never registered leaves the map unchanged
+ c.RemoveStreams(stream2)
+
+ if len(c.streams) != 2 {
+ t.Fatalf("should have two streams, has %d", len(c.streams))
+ }
+
+ // removing every registered stream empties the map
+ c.RemoveStreams(stream0, stream1)
+
+ if len(c.streams) != 0 {
+ t.Fatalf("should not have any streams, has %d", len(c.streams))
+ }
+
+}