blob: 644679410b598a718d9e5e8570bbbc8c8012e9dd [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_handshake_stream_create_helper.h"
#include <string>
#include <vector>
#include "net/base/completion_callback.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_request_info.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_response_info.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/socket_test_util.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_stream.h"
#include "net/websockets/websocket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
namespace net {
namespace {
// This class encapsulates the details of creating a mock ClientSocketHandle.
class MockClientSocketHandleFactory {
public:
MockClientSocketHandleFactory()
: histograms_("a"),
pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {}
// The created socket expects |expect_written| to be written to the socket,
// and will respond with |return_to_read|. The test will fail if the expected
// text is not written, or if all the bytes are not read.
scoped_ptr<ClientSocketHandle> CreateClientSocketHandle(
const std::string& expect_written,
const std::string& return_to_read) {
socket_factory_maker_.SetExpectations(expect_written, return_to_read);
scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
socket_handle->Init(
"a",
scoped_refptr<MockTransportSocketParams>(),
MEDIUM,
CompletionCallback(),
&pool_,
BoundNetLog());
return socket_handle.Pass();
}
private:
WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_;
ClientSocketPoolHistograms histograms_;
MockTransportClientSocketPool pool_;
DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
};
class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
public:
virtual ~TestConnectDelegate() {}
virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {}
virtual void OnFailure(const std::string& failure_message) OVERRIDE {}
virtual void OnStartOpeningHandshake(
scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {}
virtual void OnFinishOpeningHandshake(
scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {}
virtual void OnSSLCertificateError(
scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
ssl_error_callbacks,
const SSLInfo& ssl_info,
bool fatal) OVERRIDE {}
};
class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
protected:
scoped_ptr<WebSocketStream> CreateAndInitializeStream(
const std::string& socket_url,
const std::string& socket_path,
const std::vector<std::string>& sub_protocols,
const std::string& origin,
const std::string& extra_request_headers,
const std::string& extra_response_headers) {
WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_,
sub_protocols);
create_helper.set_failure_message(&failure_message_);
scoped_ptr<ClientSocketHandle> socket_handle =
socket_handle_factory_.CreateClientSocketHandle(
WebSocketStandardRequest(
socket_path, origin, extra_request_headers),
WebSocketStandardResponse(extra_response_headers));
scoped_ptr<WebSocketHandshakeStreamBase> handshake(
create_helper.CreateBasicStream(socket_handle.Pass(), false));
// If in future the implementation type returned by CreateBasicStream()
// changes, this static_cast will be wrong. However, in that case the test
// will fail and AddressSanitizer should identify the issue.
static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
HttpRequestInfo request_info;
request_info.url = GURL(socket_url);
request_info.method = "GET";
request_info.load_flags = LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN;
int rv = handshake->InitializeStream(
&request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback());
EXPECT_EQ(OK, rv);
HttpRequestHeaders headers;
headers.SetHeader("Host", "localhost");
headers.SetHeader("Connection", "Upgrade");
headers.SetHeader("Pragma", "no-cache");
headers.SetHeader("Cache-Control", "no-cache");
headers.SetHeader("Upgrade", "websocket");
headers.SetHeader("Origin", origin);
headers.SetHeader("Sec-WebSocket-Version", "13");
headers.SetHeader("User-Agent", "");
headers.SetHeader("Accept-Encoding", "gzip,deflate");
headers.SetHeader("Accept-Language", "en-us,fr");
HttpResponseInfo response;
TestCompletionCallback dummy;
rv = handshake->SendRequest(headers, &response, dummy.callback());
EXPECT_EQ(OK, rv);
rv = handshake->ReadResponseHeaders(dummy.callback());
EXPECT_EQ(OK, rv);
EXPECT_EQ(101, response.headers->response_code());
EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
return handshake->Upgrade();
}
MockClientSocketHandleFactory socket_handle_factory_;
TestConnectDelegate connect_delegate_;
std::string failure_message_;
};
// Confirm that the basic case works as expected.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
scoped_ptr<WebSocketStream> stream =
CreateAndInitializeStream("ws://localhost/", "/",
std::vector<std::string>(), "http://localhost/",
"", "");
EXPECT_EQ("", stream->GetExtensions());
EXPECT_EQ("", stream->GetSubProtocol());
}
// Verify that the sub-protocols are passed through.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
std::vector<std::string> sub_protocols;
sub_protocols.push_back("chat");
sub_protocols.push_back("superchat");
scoped_ptr<WebSocketStream> stream =
CreateAndInitializeStream("ws://localhost/",
"/",
sub_protocols,
"http://localhost/",
"Sec-WebSocket-Protocol: chat, superchat\r\n",
"Sec-WebSocket-Protocol: superchat\r\n");
EXPECT_EQ("superchat", stream->GetSubProtocol());
}
// Verify that extension name is available. Bad extension names are tested in
// websocket_stream_test.cc.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
"ws://localhost/",
"/",
std::vector<std::string>(),
"http://localhost/",
"",
"Sec-WebSocket-Extensions: permessage-deflate\r\n");
EXPECT_EQ("permessage-deflate", stream->GetExtensions());
}
// Verify that extension parameters are available. Bad parameters are tested in
// websocket_stream_test.cc.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
"ws://localhost/",
"/",
std::vector<std::string>(),
"http://localhost/",
"",
"Sec-WebSocket-Extensions: permessage-deflate;"
" client_max_window_bits=14; server_max_window_bits=14;"
" server_no_context_takeover; client_no_context_takeover\r\n");
EXPECT_EQ(
"permessage-deflate;"
" client_max_window_bits=14; server_max_window_bits=14;"
" server_no_context_takeover; client_no_context_takeover",
stream->GetExtensions());
}
} // namespace
} // namespace net