Apply HSTS to WebSocket connections.

With this change, ws: connections to hosts which have an existing HSTS
pin will be automatically changed to use wss:, ie. SSL.

In addition, Strict-Transport-Security headers that are sent from a wss:
server with a valid SSL certificate will be enforced on subsequent ws:
and http: connections to the same host.

This CL also modifies HttpNetworkTransaction to treat wss: the same as
https:.

BUG=455215, 446480
TEST=net_unittests
R=rsleevi@chromium.org, tyoshino@chromium.org

Review URL: https://codereview.chromium.org/903553005

Cr-Original-Commit-Position: refs/heads/master@{#317252}
Cr-Mirrored-From: https://chromium.googlesource.com/chromium/src
Cr-Mirrored-Commit: cb76ac67dca0a133cdfa96678ac5cd2a65af96a3
diff --git a/data/websocket/set-hsts_wsh.py b/data/websocket/set-hsts_wsh.py
new file mode 100644
index 0000000..c78a82a
--- /dev/null
+++ b/data/websocket/set-hsts_wsh.py
@@ -0,0 +1,19 @@
+# Copyright 2015 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+#
+# Add a Strict-Transport-Security header to the response.
+
+
+import json
+
+
+def web_socket_do_extra_handshake(request):
+  request.extra_headers.append(
+      ('Strict-Transport-Security', 'max-age=3600'))
+  pass
+
+
+def web_socket_transfer_data(request):
+  # Wait for closing handshake
+  request.ws_stream.receive_message()
diff --git a/http/http_network_transaction.cc b/http/http_network_transaction.cc
index e16d7c6..5ce7572 100644
--- a/http/http_network_transaction.cc
+++ b/http/http_network_transaction.cc
@@ -569,8 +569,8 @@
   OnIOComplete(ERR_HTTPS_PROXY_TUNNEL_RESPONSE);
 }
 
-bool HttpNetworkTransaction::is_https_request() const {
-  return request_->url.SchemeIs("https");
+bool HttpNetworkTransaction::IsSecureRequest() const {
+  return request_->url.SchemeIsSecure();
 }
 
 bool HttpNetworkTransaction::UsingHttpProxyWithoutTunnel() const {
@@ -969,7 +969,7 @@
   } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
     // TODO(wtc): Need a test case for this code path!
     DCHECK(stream_.get());
-    DCHECK(is_https_request());
+    DCHECK(IsSecureRequest());
     response_.cert_request_info = new SSLCertRequestInfo;
     stream_->GetSSLCertRequestInfo(response_.cert_request_info.get());
     result = HandleCertificateRequest(result);
@@ -1050,7 +1050,7 @@
   if (rv != OK)
     return rv;
 
-  if (is_https_request())
+  if (IsSecureRequest())
     stream_->GetSSLInfo(&response_.ssl_info);
 
   headers_valid_ = true;
diff --git a/http/http_network_transaction.h b/http/http_network_transaction.h
index 2c4bd93..c098d75 100644
--- a/http/http_network_transaction.h
+++ b/http/http_network_transaction.h
@@ -142,7 +142,7 @@
     STATE_NONE
   };
 
-  bool is_https_request() const;
+  bool IsSecureRequest() const;
 
   // Returns true if the request is using an HTTP(S) proxy without being
   // tunneled via the CONNECT method.
diff --git a/http/http_network_transaction_unittest.cc b/http/http_network_transaction_unittest.cc
index c3eab76..6e1ad43 100644
--- a/http/http_network_transaction_unittest.cc
+++ b/http/http_network_transaction_unittest.cc
@@ -12554,7 +12554,7 @@
     return false;
   }
 
-  void GetSSLInfo(SSLInfo* ssl_info) override { NOTREACHED(); }
+  void GetSSLInfo(SSLInfo* ssl_info) override {}
 
   void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override {
     NOTREACHED();
diff --git a/url_request/url_request.cc b/url_request/url_request.cc
index a9f8744..2f8edd5 100644
--- a/url_request/url_request.cc
+++ b/url_request/url_request.cc
@@ -993,13 +993,14 @@
 
 bool URLRequest::GetHSTSRedirect(GURL* redirect_url) const {
   const GURL& url = this->url();
-  if (!url.SchemeIs("http"))
+  bool scheme_is_http = url.SchemeIs("http");
+  if (!scheme_is_http && !url.SchemeIs("ws"))
     return false;
   TransportSecurityState* state = context()->transport_security_state();
   if (state && state->ShouldUpgradeToSSL(url.host())) {
-    url::Replacements<char> replacements;
-    const char kNewScheme[] = "https";
-    replacements.SetScheme(kNewScheme, url::Component(0, strlen(kNewScheme)));
+    GURL::Replacements replacements;
+    const char* new_scheme = scheme_is_http ? "https" : "wss";
+    replacements.SetSchemeStr(new_scheme);
     *redirect_url = url.ReplaceComponents(replacements);
     return true;
   }
diff --git a/url_request/url_request_unittest.cc b/url_request/url_request_unittest.cc
index 5937464..cfe3b71 100644
--- a/url_request/url_request_unittest.cc
+++ b/url_request/url_request_unittest.cc
@@ -7267,6 +7267,26 @@
   EXPECT_EQ(kOriginHeaderValue, received_cors_header);
 }
 
+// This just tests the behaviour of GetHSTSRedirect(). End-to-end tests of HSTS
+// are performed in net/websockets/websocket_end_to_end_test.cc.
+TEST(WebSocketURLRequestTest, HSTSApplied) {
+  TestNetworkDelegate network_delegate;
+  TransportSecurityState transport_security_state;
+  base::Time expiry = base::Time::Now() + base::TimeDelta::FromDays(1);
+  bool include_subdomains = false;
+  transport_security_state.AddHSTS("example.net", expiry, include_subdomains);
+  TestURLRequestContext context(true);
+  context.set_transport_security_state(&transport_security_state);
+  context.set_network_delegate(&network_delegate);
+  context.Init();
+  GURL ws_url("ws://example.net/echo");
+  TestDelegate delegate;
+  scoped_ptr<URLRequest> request(
+      context.CreateRequest(ws_url, DEFAULT_PRIORITY, &delegate, NULL));
+  EXPECT_TRUE(request->GetHSTSRedirect(&ws_url));
+  EXPECT_TRUE(ws_url.SchemeIs("wss"));
+}
+
 namespace {
 
 class SSLClientAuthTestDelegate : public TestDelegate {
diff --git a/websockets/websocket_end_to_end_test.cc b/websockets/websocket_end_to_end_test.cc
index 1a3df04..fc42db9 100644
--- a/websockets/websocket_end_to_end_test.cc
+++ b/websockets/websocket_end_to_end_test.cc
@@ -16,6 +16,7 @@
 #include "base/memory/scoped_ptr.h"
 #include "base/message_loop/message_loop.h"
 #include "base/run_loop.h"
+#include "base/strings/string_piece.h"
 #include "net/base/auth.h"
 #include "net/base/network_delegate.h"
 #include "net/base/test_data_directory.h"
@@ -33,6 +34,13 @@
 
 static const char kEchoServer[] = "echo-with-no-extension";
 
+// Simplify changing URL schemes.
+GURL ReplaceUrlScheme(const GURL& in_url, const base::StringPiece& scheme) {
+  GURL::Replacements replacements;
+  replacements.SetSchemeStr(scheme);
+  return in_url.ReplaceComponents(replacements);
+}
+
 // An implementation of WebSocketEventInterface that waits for and records the
 // results of the connect.
 class ConnectTestingEventInterface : public WebSocketEventInterface {
@@ -216,10 +224,10 @@
 class WebSocketEndToEndTest : public ::testing::Test {
  protected:
   WebSocketEndToEndTest()
-      : event_interface_(new ConnectTestingEventInterface),
+      : event_interface_(),
         network_delegate_(new TestNetworkDelegateWithProxyInfo),
         context_(true),
-        channel_(make_scoped_ptr(event_interface_), &context_),
+        channel_(),
         initialised_context_(false) {}
 
   // Initialise the URLRequestContext. Normally done automatically by
@@ -239,7 +247,10 @@
     }
     std::vector<std::string> sub_protocols;
     url::Origin origin("http://localhost");
-    channel_.SendAddChannelRequest(GURL(socket_url), sub_protocols, origin);
+    event_interface_ = new ConnectTestingEventInterface;
+    channel_.reset(
+        new WebSocketChannel(make_scoped_ptr(event_interface_), &context_));
+    channel_->SendAddChannelRequest(GURL(socket_url), sub_protocols, origin);
     event_interface_->WaitForResponse();
     return !event_interface_->failed();
   }
@@ -247,7 +258,7 @@
   ConnectTestingEventInterface* event_interface_;  // owned by channel_
   scoped_ptr<TestNetworkDelegateWithProxyInfo> network_delegate_;
   TestURLRequestContext context_;
-  WebSocketChannel channel_;
+  scoped_ptr<WebSocketChannel> channel_;
   bool initialised_context_;
 };
 
@@ -338,11 +349,9 @@
   // The test server doesn't have an unauthenticated proxy mode. WebSockets
   // cannot provide auth information that isn't already cached, so it's
   // necessary to preflight an HTTP request to authenticate against the proxy.
-  GURL::Replacements replacements;
-  replacements.SetSchemeStr("http");
   // It doesn't matter what the URL is, as long as it is an HTTP navigation.
   GURL http_page =
-      ws_server.GetURL("connect_check.html").ReplaceComponents(replacements);
+      ReplaceUrlScheme(ws_server.GetURL("connect_check.html"), "http");
   TestDelegate delegate;
   delegate.set_credentials(
       AuthCredentials(base::ASCIIToUTF16("foo"), base::ASCIIToUTF16("bar")));
@@ -377,6 +386,79 @@
   EXPECT_FALSE(ConnectAndWait(ws_url));
 }
 
+// Regression test for crbug.com/455215 "HSTS not applied to WebSocket"
+TEST_F(WebSocketEndToEndTest, DISABLED_ON_ANDROID(HstsHttpsToWebSocket)) {
+  SpawnedTestServer::SSLOptions ssl_options;
+  SpawnedTestServer https_server(
+      SpawnedTestServer::TYPE_HTTPS, ssl_options,
+      base::FilePath(FILE_PATH_LITERAL("net/data/url_request_unittest")));
+  SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options,
+                               GetWebSocketTestDataDirectory());
+  ASSERT_TRUE(https_server.StartInBackground());
+  ASSERT_TRUE(wss_server.StartInBackground());
+  ASSERT_TRUE(https_server.BlockUntilStarted());
+  ASSERT_TRUE(wss_server.BlockUntilStarted());
+  InitialiseContext();
+  // Set HSTS via https:
+  TestDelegate delegate;
+  GURL https_page = https_server.GetURL("files/hsts-headers.html");
+  scoped_ptr<URLRequest> request(
+      context_.CreateRequest(https_page, DEFAULT_PRIORITY, &delegate, NULL));
+  request->Start();
+  // TestDelegate exits the message loop when the request completes.
+  base::RunLoop().Run();
+  EXPECT_TRUE(request->status().is_success());
+
+  // Check HSTS with ws:
+  // Change the scheme from wss: to ws: to verify that it is switched back.
+  GURL ws_url = ReplaceUrlScheme(wss_server.GetURL(kEchoServer), "ws");
+  EXPECT_TRUE(ConnectAndWait(ws_url));
+}
+
+TEST_F(WebSocketEndToEndTest, DISABLED_ON_ANDROID(HstsWebSocketToHttps)) {
+  SpawnedTestServer::SSLOptions ssl_options;
+  SpawnedTestServer https_server(
+      SpawnedTestServer::TYPE_HTTPS, ssl_options,
+      base::FilePath(FILE_PATH_LITERAL("net/data/url_request_unittest")));
+  SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options,
+                               GetWebSocketTestDataDirectory());
+  ASSERT_TRUE(https_server.StartInBackground());
+  ASSERT_TRUE(wss_server.StartInBackground());
+  ASSERT_TRUE(https_server.BlockUntilStarted());
+  ASSERT_TRUE(wss_server.BlockUntilStarted());
+  InitialiseContext();
+  // Set HSTS via wss:
+  GURL wss_url = wss_server.GetURL("set-hsts");
+  EXPECT_TRUE(ConnectAndWait(wss_url));
+
+  // Verify via http:
+  TestDelegate delegate;
+  GURL http_page =
+      ReplaceUrlScheme(https_server.GetURL("files/simple.html"), "http");
+  scoped_ptr<URLRequest> request(
+      context_.CreateRequest(http_page, DEFAULT_PRIORITY, &delegate, NULL));
+  request->Start();
+  // TestDelegate exits the message loop when the request completes.
+  base::RunLoop().Run();
+  EXPECT_TRUE(request->status().is_success());
+  EXPECT_TRUE(request->url().SchemeIs("https"));
+}
+
+TEST_F(WebSocketEndToEndTest, DISABLED_ON_ANDROID(HstsWebSocketToWebSocket)) {
+  SpawnedTestServer::SSLOptions ssl_options;
+  SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options,
+                               GetWebSocketTestDataDirectory());
+  ASSERT_TRUE(wss_server.Start());
+  InitialiseContext();
+  // Set HSTS via wss:
+  GURL wss_url = wss_server.GetURL("set-hsts");
+  EXPECT_TRUE(ConnectAndWait(wss_url));
+
+  // Verify via wss:
+  GURL ws_url = ReplaceUrlScheme(wss_server.GetURL(kEchoServer), "ws");
+  EXPECT_TRUE(ConnectAndWait(ws_url));
+}
+
 }  // namespace
 
 }  // namespace net
diff --git a/websockets/websocket_stream.cc b/websockets/websocket_stream.cc
index b5012bc..b0abb3c 100644
--- a/websockets/websocket_stream.cc
+++ b/websockets/websocket_stream.cc
@@ -58,14 +58,7 @@
   // Implementation of URLRequest::Delegate methods.
   void OnReceivedRedirect(URLRequest* request,
                           const RedirectInfo& redirect_info,
-                          bool* defer_redirect) override {
-    // HTTP status codes returned by HttpStreamParser are filtered by
-    // WebSocketBasicHandshakeStream, and only 101, 401 and 407 are permitted
-    // back up the stack to HttpNetworkTransaction. In particular, redirect
-    // codes are never allowed, and so URLRequest never sees a redirect on a
-    // WebSocket request.
-    NOTREACHED();
-  }
+                          bool* defer_redirect) override;
 
   void OnResponseStarted(URLRequest* request) override;
 
@@ -233,6 +226,31 @@
   URLRequest* url_request_;
 };
 
+void Delegate::OnReceivedRedirect(URLRequest* request,
+                                  const RedirectInfo& redirect_info,
+                                  bool* defer_redirect) {
+  // This code should never be reached for externally generated redirects,
+  // as WebSocketBasicHandshakeStream is responsible for filtering out
+  // all response codes besides 101, 401, and 407. As such, the URLRequest
+  // should never see a redirect sent over the network. However, internal
+  // redirects also result in this method being called, such as those
+  // caused by HSTS.
+  // Because it's security critical to prevent externally-generated
+  // redirects in WebSockets, perform additional checks to ensure this
+  // is only internal.
+  GURL::Replacements replacements;
+  replacements.SetSchemeStr("wss");
+  GURL expected_url = request->original_url().ReplaceComponents(replacements);
+  if (redirect_info.new_method != "GET" ||
+      redirect_info.new_url != expected_url) {
+    // This should not happen.
+    DLOG(FATAL) << "Unauthorized WebSocket redirect to "
+                << redirect_info.new_method << " "
+                << redirect_info.new_url.spec();
+    request->Cancel();
+  }
+}
+
 void Delegate::OnResponseStarted(URLRequest* request) {
   // TODO(vadimt): Remove ScopedTracker below once crbug.com/423948 is fixed.
   tracked_objects::ScopedTracker tracking_profile(