xds/clusterimpl: update UpdateClientConnState to handle updates synchronously (#7533)
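
The run() goroutine and its pickerUpdateCh buffer are removed; all balancer
callbacks are now serialized on a grpcsync.CallbackSerializer instead.
UpdateClientConnState blocks until its callback has actually run on the
serializer, so picker updates triggered by a new configuration are applied
before the call returns rather than being forwarded to a separate goroutine.

Below is a minimal sketch (not part of the diff) of the serializer pattern
this change adopts: one blocking update followed by shutdown. The package
name and the Example function are illustrative only, and grpcsync is an
internal package, so the snippet assumes it lives inside the grpc module.

    package clusterimplsketch

    import (
    	"context"
    	"errors"
    	"fmt"

    	"google.golang.org/grpc/internal/grpcsync"
    )

    var errClosed = errors.New("LB policy is closed")

    // Example mirrors the pattern in this patch: construct the serializer,
    // perform one blocking update, then shut down.
    func Example() {
    	ctx, cancel := context.WithCancel(context.Background())
    	s := grpcsync.NewCallbackSerializer(ctx)

    	// Blocking update: schedule the work and wait on a buffered channel
    	// for its result; UpdateClientConnState in this patch does the same.
    	errCh := make(chan error, 1)
    	s.ScheduleOr(func(context.Context) {
    		errCh <- nil // the real callback runs updateClientConnState(s)
    	}, func() {
    		// Scheduling fails only after the serializer's context is
    		// canceled, i.e. after Close().
    		errCh <- errClosed
    	})
    	fmt.Println(<-errCh)

    	// Shutdown, as in Close(): cancel the serializer and wait for
    	// callbacks already scheduled to finish running.
    	cancel()
    	<-s.Done()
    }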

diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go
index 9058f0d..c8017c7 100644
--- a/xds/internal/balancer/clusterimpl/clusterimpl.go
+++ b/xds/internal/balancer/clusterimpl/clusterimpl.go
@@ -24,6 +24,7 @@
 package clusterimpl
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"sync"
@@ -33,7 +34,6 @@
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/internal/balancer/gracefulswitch"
-	"google.golang.org/grpc/internal/buffer"
 	"google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/pretty"
@@ -53,7 +53,10 @@
 	defaultRequestCountMax = 1024
 )
 
-var connectedAddress = internal.ConnectedAddress.(func(balancer.SubConnState) resolver.Address)
+var (
+	connectedAddress  = internal.ConnectedAddress.(func(balancer.SubConnState) resolver.Address)
+	errBalancerClosed = fmt.Errorf("%s LB policy is closed", Name)
+)
 
 func init() {
 	balancer.Register(bb{})
@@ -62,18 +65,17 @@
 type bb struct{}
 
 func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
+	ctx, cancel := context.WithCancel(context.Background())
 	b := &clusterImplBalancer{
-		ClientConn:      cc,
-		bOpts:           bOpts,
-		closed:          grpcsync.NewEvent(),
-		done:            grpcsync.NewEvent(),
-		loadWrapper:     loadstore.NewWrapper(),
-		pickerUpdateCh:  buffer.NewUnbounded(),
-		requestCountMax: defaultRequestCountMax,
+		ClientConn:       cc,
+		bOpts:            bOpts,
+		loadWrapper:      loadstore.NewWrapper(),
+		requestCountMax:  defaultRequestCountMax,
+		serializer:       grpcsync.NewCallbackSerializer(ctx),
+		serializerCancel: cancel,
 	}
 	b.logger = prefixLogger(b)
 	b.child = gracefulswitch.NewBalancer(b, bOpts)
-	go b.run()
 	b.logger.Infof("Created")
 	return b
 }
@@ -89,18 +91,6 @@
 type clusterImplBalancer struct {
 	balancer.ClientConn
 
-	// mu guarantees mutual exclusion between Close() and handling of picker
-	// update to the parent ClientConn in run(). It's to make sure that the
-	// run() goroutine doesn't send picker update to parent after the balancer
-	// is closed.
-	//
-	// It's only used by the run() goroutine, but not the other exported
-	// functions. Because the exported functions are guaranteed to be
-	// synchronized with Close().
-	mu     sync.Mutex
-	closed *grpcsync.Event
-	done   *grpcsync.Event
-
 	bOpts     balancer.BuildOptions
 	logger    *grpclog.PrefixLogger
 	xdsClient xdsclient.XDSClient
@@ -115,10 +105,11 @@
 	clusterNameMu sync.Mutex
 	clusterName   string
 
+	serializer       *grpcsync.CallbackSerializer
+	serializerCancel context.CancelFunc
+
 	// childState/drops/requestCounter keeps the state used by the most recently
-	// generated picker. All fields can only be accessed in run(). And run() is
-	// the only goroutine that sends picker to the parent ClientConn. All
-	// requests to update picker need to be sent to pickerUpdateCh.
+	// generated picker.
 	childState            balancer.State
 	dropCategories        []DropConfig // The categories for drops.
 	drops                 []*dropper
@@ -127,7 +118,6 @@
 	requestCounter        *xdsclient.ClusterRequestsCounter
 	requestCountMax       uint32
 	telemetryLabels       map[string]string
-	pickerUpdateCh        *buffer.Unbounded
 }
 
 // updateLoadStore checks the config for load store, and decides whether it
@@ -208,14 +198,9 @@
 	return nil
 }
 
-func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
-	if b.closed.HasFired() {
-		b.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s)
-		return nil
-	}
-
+func (b *clusterImplBalancer) updateClientConnState(s balancer.ClientConnState) error {
 	if b.logger.V(2) {
-		b.logger.Infof("Received update from resolver, balancer config: %s", pretty.ToJSON(s.BalancerConfig))
+		b.logger.Infof("Received configuration: %s", pretty.ToJSON(s.BalancerConfig))
 	}
 	newConfig, ok := s.BalancerConfig.(*LBConfig)
 	if !ok {
@@ -227,7 +212,7 @@
 	// it.
 	bb := balancer.Get(newConfig.ChildPolicy.Name)
 	if bb == nil {
-		return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name)
+		return fmt.Errorf("child policy %q not registered", newConfig.ChildPolicy.Name)
 	}
 
 	if b.xdsClient == nil {
@@ -253,9 +238,14 @@
 	}
 	b.config = newConfig
 
-	// Notify run() of this new config, in case drop and request counter need
-	// update (which means a new picker needs to be generated).
-	b.pickerUpdateCh.Put(newConfig)
+	b.telemetryLabels = newConfig.TelemetryLabels
+	dc := b.handleDropAndRequestCount(newConfig)
+	if dc != nil && b.childState.Picker != nil {
+		b.ClientConn.UpdateState(balancer.State{
+			ConnectivityState: b.childState.ConnectivityState,
+			Picker:            b.newPicker(dc),
+		})
+	}
 
 	// Addresses and sub-balancer config are sent to sub-balancer.
 	return b.child.UpdateClientConnState(balancer.ClientConnState{
@@ -264,20 +254,28 @@
 	})
 }
 
-func (b *clusterImplBalancer) ResolverError(err error) {
-	if b.closed.HasFired() {
-		b.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err)
-		return
+func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
+	// Handle the update in a blocking fashion.
+	errCh := make(chan error, 1)
+	callback := func(context.Context) {
+		errCh <- b.updateClientConnState(s)
 	}
-	b.child.ResolverError(err)
+	onFailure := func() {
+		// An attempt to schedule the callback fails only when an update is
+		// received after Close().
+		errCh <- errBalancerClosed
+	}
+	b.serializer.ScheduleOr(callback, onFailure)
+	return <-errCh
+}
+
+func (b *clusterImplBalancer) ResolverError(err error) {
+	b.serializer.TrySchedule(func(context.Context) {
+		b.child.ResolverError(err)
+	})
 }
 
 func (b *clusterImplBalancer) updateSubConnState(sc balancer.SubConn, s balancer.SubConnState, cb func(balancer.SubConnState)) {
-	if b.closed.HasFired() {
-		b.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s)
-		return
-	}
-
 	// Trigger re-resolution when a SubConn turns transient failure. This is
 	// necessary for the LogicalDNS in cluster_resolver policy to re-resolve.
 	//
@@ -299,26 +297,40 @@
 }
 
 func (b *clusterImplBalancer) Close() {
-	b.mu.Lock()
-	b.closed.Fire()
-	b.mu.Unlock()
+	b.serializer.TrySchedule(func(ctx context.Context) {
+		b.child.Close()
+		b.childState = balancer.State{}
 
-	b.child.Close()
-	b.childState = balancer.State{}
-	b.pickerUpdateCh.Close()
-	<-b.done.Done()
-	b.logger.Infof("Shutdown")
+		if b.cancelLoadReport != nil {
+			b.cancelLoadReport()
+			b.cancelLoadReport = nil
+		}
+		b.logger.Infof("Shutdown")
+	})
+	b.serializerCancel()
+	<-b.serializer.Done()
 }
 
 func (b *clusterImplBalancer) ExitIdle() {
-	b.child.ExitIdle()
+	b.serializer.TrySchedule(func(context.Context) {
+		b.child.ExitIdle()
+	})
 }
 
 // Override methods to accept updates from the child LB.
 
 func (b *clusterImplBalancer) UpdateState(state balancer.State) {
-	// Instead of updating parent ClientConn inline, send state to run().
-	b.pickerUpdateCh.Put(state)
+	b.serializer.TrySchedule(func(context.Context) {
+		b.childState = state
+		b.ClientConn.UpdateState(balancer.State{
+			ConnectivityState: b.childState.ConnectivityState,
+			Picker: b.newPicker(&dropConfigs{
+				drops:           b.drops,
+				requestCounter:  b.requestCounter,
+				requestCountMax: b.requestCountMax,
+			}),
+		})
+	})
 }
 
 func (b *clusterImplBalancer) setClusterName(n string) {
@@ -370,21 +382,23 @@
 	scw := &scWrapper{}
 	oldListener := opts.StateListener
 	opts.StateListener = func(state balancer.SubConnState) {
-		b.updateSubConnState(sc, state, oldListener)
-		if state.ConnectivityState != connectivity.Ready {
-			return
-		}
-		// Read connected address and call updateLocalityID() based on the connected
-		// address's locality. https://github.com/grpc/grpc-go/issues/7339
-		addr := connectedAddress(state)
-		lID := xdsinternal.GetLocalityID(addr)
-		if lID.Empty() {
-			if b.logger.V(2) {
-				b.logger.Infof("Locality ID for %s unexpectedly empty", addr)
+		b.serializer.TrySchedule(func(context.Context) {
+			b.updateSubConnState(sc, state, oldListener)
+			if state.ConnectivityState != connectivity.Ready {
+				return
 			}
-			return
-		}
-		scw.updateLocalityID(lID)
+			// Read the connected address and call updateLocalityID() based on the
+			// connected address's locality. https://github.com/grpc/grpc-go/issues/7339
+			addr := connectedAddress(state)
+			lID := xdsinternal.GetLocalityID(addr)
+			if lID.Empty() {
+				if b.logger.V(2) {
+					b.logger.Infof("Locality ID for %s unexpectedly empty", addr)
+				}
+				return
+			}
+			scw.updateLocalityID(lID)
+		})
 	}
 	sc, err := b.ClientConn.NewSubConn(newAddrs, opts)
 	if err != nil {
@@ -464,49 +478,3 @@
 		requestCountMax: b.requestCountMax,
 	}
 }
-
-func (b *clusterImplBalancer) run() {
-	defer b.done.Fire()
-	for {
-		select {
-		case update, ok := <-b.pickerUpdateCh.Get():
-			if !ok {
-				return
-			}
-			b.pickerUpdateCh.Load()
-			b.mu.Lock()
-			if b.closed.HasFired() {
-				b.mu.Unlock()
-				return
-			}
-			switch u := update.(type) {
-			case balancer.State:
-				b.childState = u
-				b.ClientConn.UpdateState(balancer.State{
-					ConnectivityState: b.childState.ConnectivityState,
-					Picker: b.newPicker(&dropConfigs{
-						drops:           b.drops,
-						requestCounter:  b.requestCounter,
-						requestCountMax: b.requestCountMax,
-					}),
-				})
-			case *LBConfig:
-				b.telemetryLabels = u.TelemetryLabels
-				dc := b.handleDropAndRequestCount(u)
-				if dc != nil && b.childState.Picker != nil {
-					b.ClientConn.UpdateState(balancer.State{
-						ConnectivityState: b.childState.ConnectivityState,
-						Picker:            b.newPicker(dc),
-					})
-				}
-			}
-			b.mu.Unlock()
-		case <-b.closed.Done():
-			if b.cancelLoadReport != nil {
-				b.cancelLoadReport()
-				b.cancelLoadReport = nil
-			}
-			return
-		}
-	}
-}