Merge pull request #100951 from saschagrunert/automated-cherry-pick-of-#99839-upstream-release-1.21

Automated cherry pick of #99839: Cleanup portforward streams after their usage

Kubernetes-commit: 9745a35d15c40607d94424776cc84a130c64d75f
diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go
index 00ce5f7..32f0757 100644
--- a/pkg/util/httpstream/httpstream.go
+++ b/pkg/util/httpstream/httpstream.go
@@ -78,6 +78,8 @@
 	// SetIdleTimeout sets the amount of time the connection may remain idle before
 	// it is automatically closed.
 	SetIdleTimeout(timeout time.Duration)
+	// RemoveStreams can be used to remove a set of streams from the Connection.
+	RemoveStreams(streams ...Stream)
 }
 
 // Stream represents a bidirectional communications channel that is part of an
diff --git a/pkg/util/httpstream/spdy/connection.go b/pkg/util/httpstream/spdy/connection.go
index 21b2568..3da7457 100644
--- a/pkg/util/httpstream/spdy/connection.go
+++ b/pkg/util/httpstream/spdy/connection.go
@@ -31,7 +31,7 @@
 // streams.
 type connection struct {
 	conn             *spdystream.Connection
-	streams          []httpstream.Stream
+	streams          map[uint32]httpstream.Stream
 	streamLock       sync.Mutex
 	newStreamHandler httpstream.NewStreamHandler
 	ping             func() (time.Duration, error)
@@ -85,7 +85,12 @@
 // will be invoked when the server receives a newly created stream from the
 // client.
 func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration, pingFn func() (time.Duration, error)) httpstream.Connection {
-	c := &connection{conn: conn, newStreamHandler: newStreamHandler, ping: pingFn}
+	c := &connection{
+		conn:             conn,
+		newStreamHandler: newStreamHandler,
+		ping:             pingFn,
+		streams:          make(map[uint32]httpstream.Stream),
+	}
 	go conn.Serve(c.newSpdyStream)
 	if pingPeriod > 0 && pingFn != nil {
 		go c.sendPings(pingPeriod)
@@ -105,7 +110,7 @@
 		// calling Reset instead of Close ensures that all streams are fully torn down
 		s.Reset()
 	}
-	c.streams = make([]httpstream.Stream, 0)
+	c.streams = make(map[uint32]httpstream.Stream, 0)
 	c.streamLock.Unlock()
 
 	// now that all streams are fully torn down, it's safe to call close on the underlying connection,
@@ -114,6 +119,15 @@
 	return c.conn.Close()
 }
 
+// RemoveStreams can be used to remove a set of streams from the Connection.
+func (c *connection) RemoveStreams(streams ...httpstream.Stream) {
+	c.streamLock.Lock()
+	for _, stream := range streams {
+		delete(c.streams, stream.Identifier())
+	}
+	c.streamLock.Unlock()
+}
+
 // CreateStream creates a new stream with the specified headers and registers
 // it with the connection.
 func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
@@ -133,7 +147,7 @@
 // it owns.
 func (c *connection) registerStream(s httpstream.Stream) {
 	c.streamLock.Lock()
-	c.streams = append(c.streams, s)
+	c.streams[s.Identifier()] = s
 	c.streamLock.Unlock()
 }
 
diff --git a/pkg/util/httpstream/spdy/connection_test.go b/pkg/util/httpstream/spdy/connection_test.go
index 9e551dd..edef917 100644
--- a/pkg/util/httpstream/spdy/connection_test.go
+++ b/pkg/util/httpstream/spdy/connection_test.go
@@ -290,3 +290,41 @@
 		t.Errorf("timed out waiting for server to exit")
 	}
 }
+
+type fakeStream struct{ id uint32 }
+
+func (*fakeStream) Read(p []byte) (int, error)  { return 0, nil }
+func (*fakeStream) Write(p []byte) (int, error) { return 0, nil }
+func (*fakeStream) Close() error                { return nil }
+func (*fakeStream) Reset() error                { return nil }
+func (*fakeStream) Headers() http.Header        { return nil }
+func (f *fakeStream) Identifier() uint32        { return f.id }
+
+func TestConnectionRemoveStreams(t *testing.T) {
+	c := &connection{streams: make(map[uint32]httpstream.Stream)}
+	stream0 := &fakeStream{id: 0}
+	stream1 := &fakeStream{id: 1}
+	stream2 := &fakeStream{id: 2}
+
+	c.registerStream(stream0)
+	c.registerStream(stream1)
+
+	if len(c.streams) != 2 {
+		t.Fatalf("should have two streams, has %d", len(c.streams))
+	}
+
+	// stream2 was never registered, so removing it must be a no-op
+	c.RemoveStreams(stream2)
+
+	if len(c.streams) != 2 {
+		t.Fatalf("should have two streams, has %d", len(c.streams))
+	}
+
+	// remove all existing
+	c.RemoveStreams(stream0, stream1)
+
+	if len(c.streams) != 0 {
+		t.Fatalf("should not have any streams, has %d", len(c.streams))
+	}
+
+}