// blob: 1b3fb379ea2333c5f3cf1b7b6d7244f425ef2eb1 [file] [log] [blame]
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/cast_channel/cast_transport.h"
#include <stddef.h>
#include <stdint.h>
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/containers/queue.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/run_loop.h"
#include "base/test/scoped_task_environment.h"
#include "base/test/simple_test_clock.h"
#include "components/cast_channel/cast_framer.h"
#include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_test_util.h"
#include "components/cast_channel/logger.h"
#include "components/cast_channel/proto/cast_channel.pb.h"
#include "net/base/completion_once_callback.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/net_errors.h"
#include "net/log/test_net_log.h"
#include "net/socket/socket.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "services/network/network_context.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::DoAll;
using testing::InSequence;
using testing::Invoke;
using testing::NotNull;
using testing::Return;
using testing::WithArg;
namespace cast_channel {
namespace {
// Channel ID passed to the transport under test; the same ID is used when
// querying the logger for the last recorded error.
constexpr int kChannelId = 0;
// Mock receiver for write-completion notifications. Tests bind Complete() as
// the callback handed to SendMessage() and set expectations on the result
// code it is invoked with.
class CompleteHandler {
 public:
  CompleteHandler() = default;

  MOCK_METHOD1(Complete, void(int result));

 private:
  DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
};
// Returns a CastMessage proto populated with only the bare minimum set of
// required fields, using fixed placeholder values.
CastMessage CreateCastMessage() {
  CastMessage message;

  // Protocol / routing fields.
  message.set_protocol_version(CastMessage::CASTV2_1_0);
  message.set_source_id("source");
  message.set_destination_id("destination");
  message.set_namespace_("x");

  // Payload fields.
  message.set_payload_type(CastMessage::STRING);
  message.set_payload_utf8("payload");

  return message;
}
// FIFO queue of completion callbacks. Outstanding write operations are
// Push()ed into the queue. Callback completion is simulated by invoking
// Pop() in the same order as Push().
class CompletionQueue {
public:
CompletionQueue() {}
~CompletionQueue() { CHECK_EQ(0u, cb_queue_.size()); }
// Enqueues a pending completion callback.
void Push(net::CompletionOnceCallback cb) { cb_queue_.push(std::move(cb)); }
// Runs the next callback and removes it from the queue.
void Pop(int rv) {
CHECK_GT(cb_queue_.size(), 0u);
std::move(cb_queue_.front()).Run(rv);
cb_queue_.pop();
}
private:
base::queue<net::CompletionOnceCallback> cb_queue_;
DISALLOW_COPY_AND_ASSIGN(CompletionQueue);
};
// GMock action that copies the contents of an IOBuffer argument into a
// caller-supplied string, replacing whatever the string held before.
//
// buf_idx (template parameter 0): 0-based index of the net::IOBuffer
//     in the mocked function's argument list.
// size_idx (template parameter 1): 0-based index of the byte count argument.
// str: pointer to the string that receives the buffer's data.
ACTION_TEMPLATE(ReadBufferToString,
                HAS_2_TEMPLATE_PARAMS(int, buf_idx, int, size_idx),
                AND_1_VALUE_PARAMS(str)) {
  const char* const bytes = testing::get<buf_idx>(args)->data();
  const auto byte_count = testing::get<size_idx>(args);
  *str = std::string(bytes, byte_count);
}
// GMock action that copies the full contents of a string into an IOBuffer
// argument, starting at the buffer's data() pointer.
//
// buf_idx (template parameter 0): 0-based index of the IOBuffer argument.
// str: the string whose bytes are written into the IOBuffer.
ACTION_TEMPLATE(FillBufferFromString,
                HAS_1_TEMPLATE_PARAMS(int, buf_idx),
                AND_1_VALUE_PARAMS(str)) {
  char* const destination = testing::get<buf_idx>(args)->data();
  str.copy(destination, str.size());
}
// GMock action that stores the mocked call's completion callback in a
// CompletionQueue so the test can run it later.
//
// cb_idx (template parameter 0): 0-based index of the completion callback
//     in the mocked function's argument list.
// completion_queue: pointer to the CompletionQueue that receives the callback.
ACTION_TEMPLATE(EnqueueCallback,
                HAS_1_TEMPLATE_PARAMS(int, cb_idx),
                AND_1_VALUE_PARAMS(completion_queue)) {
  const auto& pending_callback = testing::get<cb_idx>(args);
  completion_queue->Push(pending_callback);
}
} // namespace
// Mock implementation of CastTransportImpl::Channel. The Channel overrides
// receive a once-callback; each one adapts it to a repeating callback and
// forwards to the gmock-generated overload of the same name, which is the one
// tests set expectations on.
class MockSocket : public cast_channel::CastTransportImpl::Channel {
 public:
  // Channel override: forwards to the mocked Read() below.
  void Read(net::IOBuffer* buffer,
            int bytes,
            net::CompletionOnceCallback callback) override {
    net::CompletionRepeatingCallback repeating_callback =
        base::AdaptCallbackForRepeating(std::move(callback));
    Read(buffer, bytes, repeating_callback);
  }

  // Channel override: forwards to the mocked Write() below.
  void Write(net::IOBuffer* buffer,
             int bytes,
             net::CompletionOnceCallback callback) override {
    net::CompletionRepeatingCallback repeating_callback =
        base::AdaptCallbackForRepeating(std::move(callback));
    Write(buffer, bytes, repeating_callback);
  }

  // Mocked overloads taking a repeating callback.
  MOCK_METHOD3(Read,
               void(net::IOBuffer* buf,
                    int buf_len,
                    const net::CompletionRepeatingCallback& callback));
  MOCK_METHOD3(Write,
               void(net::IOBuffer* buf,
                    int buf_len,
                    const net::CompletionRepeatingCallback& callback));
};
// Test fixture that wires a CastTransportImpl to a MockSocket, a Logger and a
// MockCastTransportDelegate.
class CastTransportTest : public testing::Test {
 public:
  CastTransportTest() : logger_(new Logger()) {
    // Keep a raw pointer to the delegate for setting expectations; ownership
    // is handed to the transport through SetReadDelegate() below.
    auto owned_delegate = base::WrapUnique(new MockCastTransportDelegate);
    delegate_ = owned_delegate.get();

    transport_.reset(new CastTransportImpl(&mock_socket_, kChannelId,
                                           CreateIPEndPointForTest(), logger_));
    transport_->SetReadDelegate(std::move(owned_delegate));
  }

  ~CastTransportTest() override {}

 protected:
  // Runs every task that is currently pending, returning once idle.
  void RunPendingTasks() { base::RunLoop().RunUntilIdle(); }

  base::test::ScopedTaskEnvironment task_environment_;
  MockCastTransportDelegate* delegate_;
  MockSocket mock_socket_;
  Logger* logger_;
  std::unique_ptr<CastTransport> transport_;
};
// ----------------------------------------------------------------------------
// Asynchronous write tests
// A message is sent and the socket completes the write, asynchronously, in a
// single step covering the whole serialized message.
TEST_F(CastTransportTest, TestFullWriteAsync) {
  CompletionQueue socket_cbs;
  CompleteHandler write_handler;
  std::string output;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const int message_size = static_cast<int>(serialized_message.size());

  // The socket is asked to write the full message once; the bytes offered are
  // captured in |output| and the completion callback is parked in the queue.
  EXPECT_CALL(mock_socket_, Write(NotNull(), message_size, _))
      .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
                      EnqueueCallback<2>(&socket_cbs)));
  EXPECT_CALL(write_handler, Complete(net::OK));

  auto on_write_complete = base::BindOnce(&CompleteHandler::Complete,
                                          base::Unretained(&write_handler));
  transport_->SendMessage(message, std::move(on_write_complete));
  RunPendingTasks();

  // Simulate the socket reporting that every byte was written.
  socket_cbs.Pop(message_size);
  RunPendingTasks();

  EXPECT_EQ(serialized_message, output);
}
// A message is sent and the socket completes the write in two asynchronous
// steps: first a single byte, then the remainder.
TEST_F(CastTransportTest, TestPartialWritesAsync) {
  InSequence seq;
  CompletionQueue socket_cbs;
  CompleteHandler write_handler;
  std::string output;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const int full_size = static_cast<int>(serialized_message.size());
  const int remaining_size = static_cast<int>(serialized_message.size() - 1);

  // First write: the whole message is offered (only one byte will be reported
  // as written when the callback is popped below).
  EXPECT_CALL(mock_socket_, Write(NotNull(), full_size, _))
      .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
                      EnqueueCallback<2>(&socket_cbs)));
  // Second write: everything after the first byte is offered.
  EXPECT_CALL(mock_socket_, Write(NotNull(), remaining_size, _))
      .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
                      EnqueueCallback<2>(&socket_cbs)));

  auto on_write_complete = base::BindOnce(&CompleteHandler::Complete,
                                          base::Unretained(&write_handler));
  transport_->SendMessage(message, std::move(on_write_complete));
  RunPendingTasks();
  EXPECT_EQ(serialized_message, output);

  // Report one byte written; the transport should issue the second write.
  socket_cbs.Pop(1);
  RunPendingTasks();

  // Report the remainder written; only now is the sender notified.
  EXPECT_CALL(write_handler, Complete(net::OK));
  socket_cbs.Pop(remaining_size);
  RunPendingTasks();

  // |output| now holds what was offered to the second write.
  EXPECT_EQ(serialized_message.substr(1), output);
}
// A message is sent and the socket write completes asynchronously with a
// network error.
TEST_F(CastTransportTest, TestWriteFailureAsync) {
  CompletionQueue socket_cbs;
  CompleteHandler write_handler;
  CastMessage message = CreateCastMessage();

  EXPECT_CALL(mock_socket_, Write(NotNull(), _, _))
      .WillOnce(EnqueueCallback<2>(&socket_cbs));
  // The sender is expected to see ERR_FAILED and the delegate a socket error.
  EXPECT_CALL(write_handler, Complete(net::ERR_FAILED));
  EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));

  auto on_write_complete = base::BindOnce(&CompleteHandler::Complete,
                                          base::Unretained(&write_handler));
  transport_->SendMessage(message, std::move(on_write_complete));
  RunPendingTasks();

  // Simulate the socket failing the write.
  socket_cbs.Pop(net::ERR_CONNECTION_RESET);
  RunPendingTasks();

  // The logger should have recorded the write event and the socket's error.
  const auto last_error = logger_->GetLastError(kChannelId);
  EXPECT_EQ(ChannelEvent::SOCKET_WRITE, last_error.channel_event);
  EXPECT_EQ(net::ERR_CONNECTION_RESET, last_error.net_return_value);
}
// ----------------------------------------------------------------------------
// Asynchronous read tests
// A complete message arrives in two asynchronous reads: the header, then the
// whole body. The delegate should receive the parsed message.
TEST_F(CastTransportTest, TestFullReadAsync) {
  InSequence s;
  CompletionQueue socket_cbs;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const auto header_size = MessageFramer::MessageHeader::header_size();
  const auto body_size = serialized_message.size() - header_size;

  EXPECT_CALL(*delegate_, Start());

  // Header read.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
                      EnqueueCallback<2>(&socket_cbs)));

  // Body read: everything after the header.
  EXPECT_CALL(mock_socket_, Read(NotNull(), body_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(
                          serialized_message.substr(header_size, body_size)),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message)));

  // After delivering the message the transport asks for the next header.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _));

  transport_->Start();
  RunPendingTasks();

  // Complete the header read, then the body read.
  socket_cbs.Pop(header_size);
  socket_cbs.Pop(body_size);
  RunPendingTasks();
}
// A message arrives in three asynchronous reads: the header, the body minus
// its last byte, and finally that last byte.
TEST_F(CastTransportTest, TestPartialReadAsync) {
  InSequence s;
  CompletionQueue socket_cbs;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const auto header_size = MessageFramer::MessageHeader::header_size();
  const auto body_size = serialized_message.size() - header_size;

  EXPECT_CALL(*delegate_, Start());

  // Header read.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  // Body read: the full body is requested, but only all-but-one byte is
  // supplied.
  EXPECT_CALL(mock_socket_, Read(NotNull(), body_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
                          header_size, body_size - 1)),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  // Follow-up read for the one outstanding byte.
  EXPECT_CALL(mock_socket_, Read(NotNull(), 1, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
                          serialized_message.size() - 1)),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message)));

  transport_->Start();

  // Complete the header read, then the short body read.
  socket_cbs.Pop(header_size);
  socket_cbs.Pop(body_size - 1);

  // Once the last byte arrives the transport asks for the next header.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _));
  socket_cbs.Pop(1);
}
// The asynchronous header read completes with a network error.
TEST_F(CastTransportTest, TestReadErrorInHeaderAsync) {
  CompletionQueue socket_cbs;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const auto header_size = MessageFramer::MessageHeader::header_size();

  EXPECT_CALL(*delegate_, Start());

  // Header read.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));

  transport_->Start();

  // Fail the header read.
  socket_cbs.Pop(net::ERR_CONNECTION_RESET);

  // The logger should have recorded the read event and the socket's error.
  const auto last_error = logger_->GetLastError(kChannelId);
  EXPECT_EQ(ChannelEvent::SOCKET_READ, last_error.channel_event);
  EXPECT_EQ(net::ERR_CONNECTION_RESET, last_error.net_return_value);
}
// The header read succeeds, then the asynchronous body read completes with a
// network error.
TEST_F(CastTransportTest, TestReadErrorInBodyAsync) {
  CompletionQueue socket_cbs;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const auto header_size = MessageFramer::MessageHeader::header_size();
  const auto body_size = serialized_message.size() - header_size;

  EXPECT_CALL(*delegate_, Start());

  // Header read.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  // Body read: the full body is requested; all but its last byte is supplied.
  EXPECT_CALL(mock_socket_, Read(NotNull(), body_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
                          header_size, body_size - 1)),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));

  transport_->Start();

  // Header read is OK.
  socket_cbs.Pop(header_size);
  // Body read fails.
  socket_cbs.Pop(net::ERR_CONNECTION_RESET);

  // The logger should have recorded the read event and the socket's error.
  const auto last_error = logger_->GetLastError(kChannelId);
  EXPECT_EQ(ChannelEvent::SOCKET_READ, last_error.channel_event);
  EXPECT_EQ(net::ERR_CONNECTION_RESET, last_error.net_return_value);
}
// A message with a valid header but a corrupted body is read asynchronously;
// the delegate is expected to be told the message is invalid.
TEST_F(CastTransportTest, TestReadCorruptedMessageAsync) {
  CompletionQueue socket_cbs;
  CastMessage message = CreateCastMessage();
  std::string serialized_message;
  EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
  const auto header_size = MessageFramer::MessageHeader::header_size();
  const auto body_size = serialized_message.size() - header_size;

  // Corrupt the serialized message body (set every body byte to 'x'); the
  // header is left intact so the announced body length is still correct.
  for (size_t i = header_size; i < serialized_message.size(); ++i) {
    serialized_message[i] = 'x';
  }

  EXPECT_CALL(*delegate_, Start());

  // Read bytes [0, 3]: the header.
  EXPECT_CALL(mock_socket_, Read(NotNull(), header_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  // Read bytes [4, n]: the entire corrupted body. The number of bytes copied
  // into the buffer matches the byte count reported by Pop() below.
  EXPECT_CALL(mock_socket_, Read(NotNull(), body_size, _))
      .WillOnce(DoAll(FillBufferFromString<0>(
                          serialized_message.substr(header_size, body_size)),
                      EnqueueCallback<2>(&socket_cbs)))
      .RetiresOnSaturation();

  EXPECT_CALL(*delegate_, OnError(ChannelError::INVALID_MESSAGE));

  transport_->Start();

  // Complete the header read, then report the whole body as read.
  socket_cbs.Pop(header_size);
  socket_cbs.Pop(body_size);
}
} // namespace cast_channel