// blob: 8a492e737713454001826a7ea5052c7a6efd7a4d [file] [log] [blame]
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket.h"
#include <string>
#include <vector>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/mock_host_resolver.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/socket_test_util.h"
#include "net/url_request/url_request_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/platform_test.h"
// Snapshot of one delegate notification, as stored by the
// WebSocketEventRecorder defined in this file.
struct WebSocketEvent {
  // Identifies which delegate method produced the record.
  enum EventType {
    EVENT_OPEN,
    EVENT_MESSAGE,
    EVENT_ERROR,
    EVENT_CLOSE,
  };

  // Kind of notification that was observed.
  EventType event_type;
  // Socket that was passed to the delegate method (not owned).
  net::WebSocket* socket;
  // Message text; the recorder fills it only for EVENT_MESSAGE and leaves it
  // empty otherwise.
  std::string msg;
  // The recorder stores |was_clean| here for EVENT_CLOSE and false otherwise.
  bool flag;

  // Copies every argument into the field of the matching meaning.
  WebSocketEvent(EventType type, net::WebSocket* websocket,
                 const std::string& websocket_msg, bool websocket_flag)
      : event_type(type),
        socket(websocket),
        msg(websocket_msg),
        flag(websocket_flag) {}
};
// WebSocketDelegate implementation for tests.  Every notification is appended
// to an in-memory list as a WebSocketEvent, and an optional per-event hook is
// run right after the event has been recorded.
//
// The WebSocketEvent* handed to a hook points into the internal vector, so it
// is only guaranteed to stay valid until the next event is recorded (a
// push_back may reallocate the storage).
class WebSocketEventRecorder : public net::WebSocketDelegate {
 public:
  // |callback| may be NULL.  When it is non-NULL it is run with net::OK at the
  // end of every OnClose() notification.  The pointer is not owned.
  explicit WebSocketEventRecorder(net::CompletionCallback* callback)
      : callback_(callback) {}
  virtual ~WebSocketEventRecorder() {}

  // Installs the hook run after an EVENT_OPEN has been recorded.
  void SetOnOpen(const base::Callback<void(WebSocketEvent*)>& callback) {
    onopen_ = callback;
  }
  // Installs the hook run after an EVENT_MESSAGE has been recorded.
  void SetOnMessage(const base::Callback<void(WebSocketEvent*)>& callback) {
    onmessage_ = callback;
  }
  // Installs the hook run after an EVENT_ERROR has been recorded.  Without
  // this setter |onerror_| could never be non-null and the hook in OnError()
  // was unreachable.
  void SetOnError(const base::Callback<void(WebSocketEvent*)>& callback) {
    onerror_ = callback;
  }
  // Installs the hook run after an EVENT_CLOSE has been recorded.
  void SetOnClose(const base::Callback<void(WebSocketEvent*)>& callback) {
    onclose_ = callback;
  }

  // net::WebSocketDelegate notifications.  Each one records the event first
  // and then runs the matching hook, if one has been installed.
  virtual void OnOpen(net::WebSocket* socket) {
    events_.push_back(
        WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket,
                       std::string(), false));
    if (!onopen_.is_null())
      onopen_.Run(&events_.back());
  }
  virtual void OnMessage(net::WebSocket* socket, const std::string& msg) {
    events_.push_back(
        WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false));
    if (!onmessage_.is_null())
      onmessage_.Run(&events_.back());
  }
  virtual void OnError(net::WebSocket* socket) {
    events_.push_back(
        WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket,
                       std::string(), false));
    if (!onerror_.is_null())
      onerror_.Run(&events_.back());
  }
  virtual void OnClose(net::WebSocket* socket, bool was_clean) {
    events_.push_back(
        WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket,
                       std::string(), was_clean));
    if (!onclose_.is_null())
      onclose_.Run(&events_.back());
    // Signal the completion callback last, after the close hook has run.
    if (callback_)
      callback_->Run(net::OK);
  }

  // Hook body usable with the setters above: closes the socket that produced
  // |event|.
  void DoClose(WebSocketEvent* event) {
    event->socket->Close();
  }

  // Returns every event recorded so far, in the order it was observed.
  const std::vector<WebSocketEvent>& GetSeenEvents() const {
    return events_;
  }

 private:
  // Recorded notifications, oldest first.
  std::vector<WebSocketEvent> events_;
  // Optional hooks; a null callback means "no hook installed".
  base::Callback<void(WebSocketEvent*)> onopen_;
  base::Callback<void(WebSocketEvent*)> onmessage_;
  base::Callback<void(WebSocketEvent*)> onerror_;
  base::Callback<void(WebSocketEvent*)> onclose_;
  // Not owned; may be NULL.  Run with net::OK from OnClose().
  net::CompletionCallback* callback_;

  DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
};
namespace net {
class WebSocketTest : public PlatformTest {
protected:
void InitReadBuf(WebSocket* websocket) {
// Set up |current_read_buf_|.
websocket->current_read_buf_ = new GrowableIOBuffer();
}
void SetReadConsumed(WebSocket* websocket, int consumed) {
websocket->read_consumed_len_ = consumed;
}
void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
websocket->AddToReadBuffer(data, len);
}
void TestProcessFrameData(WebSocket* websocket,
const char* expected_remaining_data,
int expected_remaining_len) {
websocket->ProcessFrameData();
const char* actual_remaining_data =
websocket->current_read_buf_->StartOfBuffer()
+ websocket->read_consumed_len_;
int actual_remaining_len =
websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
EXPECT_EQ(expected_remaining_len, actual_remaining_len);
EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
expected_remaining_len));
}
};
// Performs a DRAFT75 handshake against a mock socket and closes the WebSocket
// from the delegate's open hook.  Expects exactly an open event followed by a
// close event.
TEST_F(WebSocketTest, Connect) {
  MockClientSocketFactory mock_socket_factory;
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
             "Upgrade: WebSocket\r\n"
             "Connection: Upgrade\r\n"
             "WebSocket-Origin: http://example.com\r\n"
             "WebSocket-Location: ws://example.com/demo\r\n"
             "WebSocket-Protocol: sample\r\n"
             "\r\n"),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING),
  };
  MockWrite data_writes[] = {
    MockWrite("GET /demo HTTP/1.1\r\n"
              "Upgrade: WebSocket\r\n"
              "Connection: Upgrade\r\n"
              "Host: example.com\r\n"
              "Origin: http://example.com\r\n"
              "WebSocket-Protocol: sample\r\n"
              "\r\n"),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  MockHostResolver host_resolver;

  // NOTE(review): |request| is handed to the WebSocket constructor below and
  // never deleted here; presumably WebSocket takes ownership -- confirm.
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  request->SetHostResolver(&host_resolver);
  request->SetClientSocketFactory(&mock_socket_factory);

  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));
  // Close the socket as soon as the open notification arrives.
  delegate->SetOnOpen(base::Bind(&WebSocketEventRecorder::DoClose,
                                 base::Unretained(delegate.get())));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));
  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
  websocket->Connect();

  // The recorder runs |callback| from OnClose().
  callback.WaitForResult();

  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the indexing below would be out of bounds if fewer events
  // than expected were recorded.
  ASSERT_EQ(2U, events.size());
  EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
  EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type);
}
// After the DRAFT75 handshake the mock server delivers one frame
// (0x00 "Hello" 0xFF).  The delegate closes the socket from its message hook.
// Expects open, message ("Hello") and close events, in that order.
TEST_F(WebSocketTest, ServerSentData) {
  MockClientSocketFactory mock_socket_factory;
  static const char kMessage[] = "Hello";
  static const char kFrame[] = "\x00Hello\xff";
  // Excludes the terminating NUL the compiler appends to the literal.
  static const int kFrameLen = sizeof(kFrame) - 1;
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
             "Upgrade: WebSocket\r\n"
             "Connection: Upgrade\r\n"
             "WebSocket-Origin: http://example.com\r\n"
             "WebSocket-Location: ws://example.com/demo\r\n"
             "WebSocket-Protocol: sample\r\n"
             "\r\n"),
    MockRead(true, kFrame, kFrameLen),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING),
  };
  MockWrite data_writes[] = {
    MockWrite("GET /demo HTTP/1.1\r\n"
              "Upgrade: WebSocket\r\n"
              "Connection: Upgrade\r\n"
              "Host: example.com\r\n"
              "Origin: http://example.com\r\n"
              "WebSocket-Protocol: sample\r\n"
              "\r\n"),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  MockHostResolver host_resolver;

  // NOTE(review): |request| is handed to the WebSocket constructor below and
  // never deleted here; presumably WebSocket takes ownership -- confirm.
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  request->SetHostResolver(&host_resolver);
  request->SetClientSocketFactory(&mock_socket_factory);

  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));
  // Close the socket as soon as the first message arrives.
  delegate->SetOnMessage(base::Bind(&WebSocketEventRecorder::DoClose,
                                    base::Unretained(delegate.get())));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));
  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
  websocket->Connect();

  // The recorder runs |callback| from OnClose().
  callback.WaitForResult();

  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the indexing below would be out of bounds if fewer events
  // than expected were recorded.
  ASSERT_EQ(3U, events.size());
  EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
  EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type);
  EXPECT_EQ(kMessage, events[1].msg);
  EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
}
// Feeds ProcessFrameData() a length-prefixed frame followed by a second
// length-prefixed frame whose declared length exceeds the buffered data.
// Expects the first frame to be consumed, the second to be left in the
// buffer, and a single error event (no message event) to be recorded.
TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
  // NOTE(review): |request| is handed to the WebSocket constructor below and
  // never deleted here; presumably WebSocket takes ownership -- confirm.
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  // Frame data: skip length 1 ('x'), and try to skip length 129
  // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip.
  static const char kTestLengthFrame[] =
      "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF";
  // Excludes the terminating NUL the compiler appends to the literal.
  const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
  SetReadConsumed(websocket.get(), 0);

  // Everything after the first three bytes ("\x80\x01x") should remain.
  static const char kExpectedRemainingFrame[] =
      "\x80\x81\x01\x01\x00unexpected data\xFF";
  const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
  TestProcessFrameData(websocket.get(),
                       kExpectedRemainingFrame, kExpectedRemainingLength);

  // No onmessage event expected.
  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  EXPECT_EQ(1U, events.size());
  // Guard the element access instead of using a fatal ASSERT so that an
  // unexpected event count neither reads out of bounds nor skips the
  // DetachDelegate() call below.
  if (!events.empty())
    EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type);

  websocket->DetachDelegate();
}
// Feeds ProcessFrameData() a 0x00-prefixed frame split across two reads.
// After the first part nothing may be consumed or reported; once the 0xFF
// terminator arrives, the whole buffer is consumed and one message event
// carrying the concatenated text is recorded.
TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
  // NOTE(review): |request| is handed to the WebSocket constructor below and
  // never deleted here; presumably WebSocket takes ownership -- confirm.
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  // First read: frame start without its terminator.
  static const char kTestUnterminatedFrame[] =
      "\x00unterminated frame";
  // Excludes the terminating NUL the compiler appends to the literal.
  const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
               kTestUnterminatedFrameLength);
  SetReadConsumed(websocket.get(), 0);
  // The incomplete frame must stay in the buffer untouched.
  TestProcessFrameData(websocket.get(),
                       kTestUnterminatedFrame, kTestUnterminatedFrameLength);
  {
    // No onmessage event expected.
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    EXPECT_EQ(0U, events.size());
  }

  // Second read: rest of the text plus the 0xFF terminator.
  static const char kTestTerminateFrame[] = " is terminated in next read\xff";
  const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
  AddToReadBuf(websocket.get(), kTestTerminateFrame,
               kTestTerminateFrameLength);
  // Now the whole buffer should be consumed.
  TestProcessFrameData(websocket.get(), "", 0);

  static const char kExpectedMsg[] =
      "unterminated frame is terminated in next read";
  {
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    EXPECT_EQ(1U, events.size());
    // Guard the element access instead of using a fatal ASSERT so that an
    // unexpected event count neither reads out of bounds nor skips the
    // DetachDelegate() call below.
    if (!events.empty()) {
      EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
      EXPECT_EQ(kExpectedMsg, events[0].msg);
    }
  }

  websocket->DetachDelegate();
}
} // namespace net