credentials/tls: reject connections with ALPN disabled (#7184)

diff --git a/credentials/tls.go b/credentials/tls.go
index 5dafd34..4114358 100644
--- a/credentials/tls.go
+++ b/credentials/tls.go
@@ -27,9 +27,13 @@
 	"net/url"
 	"os"
 
+	"google.golang.org/grpc/grpclog"
 	credinternal "google.golang.org/grpc/internal/credentials"
+	"google.golang.org/grpc/internal/envconfig"
 )
 
+var logger = grpclog.Component("credentials")
+
 // TLSInfo contains the auth information for a TLS authenticated connection.
 // It implements the AuthInfo interface.
 type TLSInfo struct {
@@ -112,6 +116,22 @@
 		conn.Close()
 		return nil, nil, ctx.Err()
 	}
+
+	// The negotiated protocol can be either of the following:
+	// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
+	//    it is the only protocol advertised by the client during the handshake.
+	//    The tls library ensures that the server chooses a protocol advertised
+	//    by the client.
+	// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
+	//    for using HTTP/2 over TLS. We can terminate the connection immediately.
+	np := conn.ConnectionState().NegotiatedProtocol
+	if np == "" {
+		if envconfig.EnforceALPNEnabled {
+			conn.Close()
+			return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
+		}
+		logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
+	}
 	tlsInfo := TLSInfo{
 		State: conn.ConnectionState(),
 		CommonAuthInfo: CommonAuthInfo{
@@ -131,8 +151,20 @@
 		conn.Close()
 		return nil, nil, err
 	}
+	cs := conn.ConnectionState()
+	// The negotiated application protocol can be empty only if the client doesn't
+	// support ALPN. In such cases, we can close the connection since ALPN is required
+	// for using HTTP/2 over TLS.
+	if cs.NegotiatedProtocol == "" {
+		if envconfig.EnforceALPNEnabled {
+			conn.Close()
+			return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
+		} else if logger.V(2) {
+			logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
+		}
+	}
 	tlsInfo := TLSInfo{
-		State: conn.ConnectionState(),
+		State: cs,
 		CommonAuthInfo: CommonAuthInfo{
 			SecurityLevel: PrivacyAndIntegrity,
 		},
diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go
index 2541b8d..c817777 100644
--- a/credentials/tls_ext_test.go
+++ b/credentials/tls_ext_test.go
@@ -23,6 +23,7 @@
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
+	"net"
 	"os"
 	"strings"
 	"testing"
@@ -31,6 +32,7 @@
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/internal/envconfig"
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/stubserver"
 	"google.golang.org/grpc/status"
@@ -236,3 +238,160 @@
 		t.Fatalf("EmptyCall err = %v; want <nil>", err)
 	}
 }
+
+// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
+// connecting to a server that doesn't support ALPN.
+func (s) TestTLS_DisabledALPNClient(t *testing.T) {
+	initialVal := envconfig.EnforceALPNEnabled
+	defer func() {
+		envconfig.EnforceALPNEnabled = initialVal
+	}()
+
+	tests := []struct {
+		name         string
+		alpnEnforced bool
+		wantErr      bool
+	}{
+		{
+			name:         "enforced",
+			alpnEnforced: true,
+			wantErr:      true,
+		},
+		{
+			name: "not_enforced",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+			listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
+				Certificates: []tls.Certificate{serverCert},
+				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+			})
+			if err != nil {
+				t.Fatalf("Error starting TLS server: %v", err)
+			}
+
+			errCh := make(chan error, 1)
+			go func() {
+				conn, err := listener.Accept()
+				if err != nil {
+					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+				} else {
+					// The first write to the TLS listener initiates the TLS handshake.
+					conn.Write([]byte("Hello, World!"))
+					conn.Close()
+				}
+				close(errCh)
+			}()
+
+			serverAddr := listener.Addr().String()
+			conn, err := net.Dial("tcp", serverAddr)
+			if err != nil {
+				t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+			}
+			defer conn.Close()
+
+			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+			defer cancel()
+
+			clientCfg := tls.Config{
+				ServerName: serverName,
+				RootCAs:    certPool,
+				NextProtos: []string{"h2"},
+			}
+			_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
+
+			if gotErr := (err != nil); gotErr != tc.wantErr {
+				t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+			}
+
+			select {
+			case err := <-errCh:
+				if err != nil {
+					t.Fatalf("Unexpected error received from server: %v", err)
+				}
+			case <-ctx.Done():
+				t.Fatalf("Timeout waiting for error from server")
+			}
+		})
+	}
+}
+
+// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
+// accepting a request from a client that doesn't support ALPN.
+func (s) TestTLS_DisabledALPNServer(t *testing.T) {
+	initialVal := envconfig.EnforceALPNEnabled
+	defer func() {
+		envconfig.EnforceALPNEnabled = initialVal
+	}()
+
+	tests := []struct {
+		name         string
+		alpnEnforced bool
+		wantErr      bool
+	}{
+		{
+			name:         "enforced",
+			alpnEnforced: true,
+			wantErr:      true,
+		},
+		{
+			name: "not_enforced",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+			listener, err := net.Listen("tcp", "localhost:0")
+			if err != nil {
+				t.Fatalf("Error starting server: %v", err)
+			}
+
+			errCh := make(chan error, 1)
+			go func() {
+				conn, err := listener.Accept()
+				if err != nil {
+					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+					return
+				}
+				defer conn.Close()
+				serverCfg := tls.Config{
+					Certificates: []tls.Certificate{serverCert},
+					NextProtos:   []string{"h2"},
+				}
+				_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
+				if gotErr := (err != nil); gotErr != tc.wantErr {
+					t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+				}
+				close(errCh)
+			}()
+
+			serverAddr := listener.Addr().String()
+			clientCfg := &tls.Config{
+				Certificates: []tls.Certificate{serverCert},
+				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+				RootCAs:      certPool,
+				ServerName:   serverName,
+			}
+			conn, err := tls.Dial("tcp", serverAddr, clientCfg)
+			if err != nil {
+				t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+			}
+			defer conn.Close()
+
+			select {
+			case <-time.After(defaultTestTimeout):
+				t.Fatal("Timed out waiting for completion")
+			case err := <-errCh:
+				if err != nil {
+					t.Fatalf("Unexpected server error: %v", err)
+				}
+			}
+		})
+	}
+}
diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go
index 9c915d9..d906487 100644
--- a/internal/envconfig/envconfig.go
+++ b/internal/envconfig/envconfig.go
@@ -40,6 +40,12 @@
 	// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
 	// handshakes that can be performed.
 	ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
+	// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
+	// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
+	// option is present for backward compatibility. This option may be overridden
+	// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
+	// or "false".
+	EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
 )
 
 func boolFromEnv(envVar string, def bool) bool {